Repository: mwaskom/seaborn Branch: master Commit: 32088bbc3adc Files: 310 Total size: 2.6 MB Directory structure: gitextract_2fyw04_k/ ├── .github/ │ ├── CONTRIBUTING.md │ ├── dependabot.yml │ └── workflows/ │ └── ci.yaml ├── .gitignore ├── .pre-commit-config.yaml ├── CITATION.cff ├── LICENSE.md ├── Makefile ├── README.md ├── SECURITY.md ├── ci/ │ ├── cache_datasets.py │ ├── check_gallery.py │ ├── deps_pinned.txt │ └── getmsfonts.sh ├── doc/ │ ├── .gitignore │ ├── Makefile │ ├── README.md │ ├── _docstrings/ │ │ ├── FacetGrid.ipynb │ │ ├── JointGrid.ipynb │ │ ├── Makefile │ │ ├── PairGrid.ipynb │ │ ├── axes_style.ipynb │ │ ├── barplot.ipynb │ │ ├── blend_palette.ipynb │ │ ├── boxenplot.ipynb │ │ ├── boxplot.ipynb │ │ ├── catplot.ipynb │ │ ├── clustermap.ipynb │ │ ├── color_palette.ipynb │ │ ├── countplot.ipynb │ │ ├── cubehelix_palette.ipynb │ │ ├── dark_palette.ipynb │ │ ├── displot.ipynb │ │ ├── diverging_palette.ipynb │ │ ├── ecdfplot.ipynb │ │ ├── heatmap.ipynb │ │ ├── histplot.ipynb │ │ ├── hls_palette.ipynb │ │ ├── husl_palette.ipynb │ │ ├── jointplot.ipynb │ │ ├── kdeplot.ipynb │ │ ├── light_palette.ipynb │ │ ├── lineplot.ipynb │ │ ├── lmplot.ipynb │ │ ├── move_legend.ipynb │ │ ├── mpl_palette.ipynb │ │ ├── objects.Agg.ipynb │ │ ├── objects.Area.ipynb │ │ ├── objects.Band.ipynb │ │ ├── objects.Bar.ipynb │ │ ├── objects.Bars.ipynb │ │ ├── objects.Count.ipynb │ │ ├── objects.Dash.ipynb │ │ ├── objects.Dodge.ipynb │ │ ├── objects.Dot.ipynb │ │ ├── objects.Dots.ipynb │ │ ├── objects.Est.ipynb │ │ ├── objects.Hist.ipynb │ │ ├── objects.Jitter.ipynb │ │ ├── objects.KDE.ipynb │ │ ├── objects.Line.ipynb │ │ ├── objects.Lines.ipynb │ │ ├── objects.Norm.ipynb │ │ ├── objects.Path.ipynb │ │ ├── objects.Paths.ipynb │ │ ├── objects.Perc.ipynb │ │ ├── objects.Plot.add.ipynb │ │ ├── objects.Plot.config.ipynb │ │ ├── objects.Plot.facet.ipynb │ │ ├── objects.Plot.label.ipynb │ │ ├── objects.Plot.layout.ipynb │ │ ├── objects.Plot.limit.ipynb │ │ ├── objects.Plot.on.ipynb │ │ 
├── objects.Plot.pair.ipynb │ │ ├── objects.Plot.scale.ipynb │ │ ├── objects.Plot.share.ipynb │ │ ├── objects.Plot.theme.ipynb │ │ ├── objects.Range.ipynb │ │ ├── objects.Shift.ipynb │ │ ├── objects.Stack.ipynb │ │ ├── objects.Text.ipynb │ │ ├── pairplot.ipynb │ │ ├── plotting_context.ipynb │ │ ├── pointplot.ipynb │ │ ├── regplot.ipynb │ │ ├── relplot.ipynb │ │ ├── residplot.ipynb │ │ ├── rugplot.ipynb │ │ ├── scatterplot.ipynb │ │ ├── set_context.ipynb │ │ ├── set_style.ipynb │ │ ├── set_theme.ipynb │ │ ├── stripplot.ipynb │ │ ├── swarmplot.ipynb │ │ └── violinplot.ipynb │ ├── _static/ │ │ ├── copybutton.js │ │ └── css/ │ │ └── custom.css │ ├── _templates/ │ │ ├── autosummary/ │ │ │ ├── base.rst │ │ │ ├── class.rst │ │ │ ├── object.rst │ │ │ ├── plot.rst │ │ │ └── scale.rst │ │ ├── layout.html │ │ └── version.html │ ├── _tutorial/ │ │ ├── Makefile │ │ ├── aesthetics.ipynb │ │ ├── axis_grids.ipynb │ │ ├── categorical.ipynb │ │ ├── color_palettes.ipynb │ │ ├── data_structure.ipynb │ │ ├── distributions.ipynb │ │ ├── error_bars.ipynb │ │ ├── function_overview.ipynb │ │ ├── introduction.ipynb │ │ ├── objects_interface.ipynb │ │ ├── properties.ipynb │ │ ├── regression.ipynb │ │ └── relational.ipynb │ ├── api.rst │ ├── citing.rst │ ├── conf.py │ ├── example_thumbs/ │ │ └── .gitkeep │ ├── faq.rst │ ├── index.rst │ ├── installing.rst │ ├── make.bat │ ├── matplotlibrc │ ├── sphinxext/ │ │ ├── gallery_generator.py │ │ └── tutorial_builder.py │ ├── tools/ │ │ ├── extract_examples.py │ │ ├── generate_logos.py │ │ ├── nb_to_doc.py │ │ └── set_nb_kernels.py │ ├── tutorial.yaml │ └── whatsnew/ │ ├── index.rst │ ├── v0.10.0.rst │ ├── v0.10.1.rst │ ├── v0.11.0.rst │ ├── v0.11.1.rst │ ├── v0.11.2.rst │ ├── v0.12.0.rst │ ├── v0.12.1.rst │ ├── v0.12.2.rst │ ├── v0.13.0.rst │ ├── v0.13.1.rst │ ├── v0.13.2.rst │ ├── v0.2.0.rst │ ├── v0.2.1.rst │ ├── v0.3.0.rst │ ├── v0.3.1.rst │ ├── v0.4.0.rst │ ├── v0.5.0.rst │ ├── v0.5.1.rst │ ├── v0.6.0.rst │ ├── v0.7.0.rst │ ├── v0.7.1.rst │ ├── 
v0.8.0.rst │ ├── v0.8.1.rst │ ├── v0.9.0.rst │ └── v0.9.1.rst ├── examples/ │ ├── .gitignore │ ├── anscombes_quartet.py │ ├── different_scatter_variables.py │ ├── errorband_lineplots.py │ ├── faceted_histogram.py │ ├── faceted_lineplot.py │ ├── grouped_barplot.py │ ├── grouped_boxplot.py │ ├── grouped_violinplots.py │ ├── heat_scatter.py │ ├── hexbin_marginals.py │ ├── histogram_stacked.py │ ├── horizontal_boxplot.py │ ├── jitter_stripplot.py │ ├── joint_histogram.py │ ├── joint_kde.py │ ├── kde_ridgeplot.py │ ├── large_distributions.py │ ├── layered_bivariate_plot.py │ ├── logistic_regression.py │ ├── many_facets.py │ ├── many_pairwise_correlations.py │ ├── marginal_ticks.py │ ├── multiple_bivariate_kde.py │ ├── multiple_conditional_kde.py │ ├── multiple_ecdf.py │ ├── multiple_regression.py │ ├── pair_grid_with_kde.py │ ├── paired_pointplots.py │ ├── pairgrid_dotplot.py │ ├── palette_choices.py │ ├── palette_generation.py │ ├── part_whole_bars.py │ ├── pointplot_anova.py │ ├── radial_facets.py │ ├── regression_marginals.py │ ├── residplot.py │ ├── scatter_bubbles.py │ ├── scatterplot_categorical.py │ ├── scatterplot_matrix.py │ ├── scatterplot_sizes.py │ ├── simple_violinplots.py │ ├── smooth_bivariate_kde.py │ ├── spreadsheet_heatmap.py │ ├── strip_regplot.py │ ├── structured_heatmap.py │ ├── three_variable_histogram.py │ ├── timeseries_facets.py │ ├── wide_data_lineplot.py │ └── wide_form_violinplot.py ├── licences/ │ ├── APPDIRS_LICENSE │ ├── HUSL_LICENSE │ ├── NUMPYDOC_LICENSE │ ├── PACKAGING_LICENSE │ └── SCIPY_LICENSE ├── pyproject.toml ├── seaborn/ │ ├── __init__.py │ ├── _base.py │ ├── _compat.py │ ├── _core/ │ │ ├── __init__.py │ │ ├── data.py │ │ ├── exceptions.py │ │ ├── groupby.py │ │ ├── moves.py │ │ ├── plot.py │ │ ├── properties.py │ │ ├── rules.py │ │ ├── scales.py │ │ ├── subplots.py │ │ └── typing.py │ ├── _docstrings.py │ ├── _marks/ │ │ ├── __init__.py │ │ ├── area.py │ │ ├── bar.py │ │ ├── base.py │ │ ├── dot.py │ │ ├── line.py │ │ └── text.py 
│ ├── _statistics.py │ ├── _stats/ │ │ ├── __init__.py │ │ ├── aggregation.py │ │ ├── base.py │ │ ├── counting.py │ │ ├── density.py │ │ ├── order.py │ │ └── regression.py │ ├── _testing.py │ ├── algorithms.py │ ├── axisgrid.py │ ├── categorical.py │ ├── cm.py │ ├── colors/ │ │ ├── __init__.py │ │ ├── crayons.py │ │ └── xkcd_rgb.py │ ├── distributions.py │ ├── external/ │ │ ├── __init__.py │ │ ├── appdirs.py │ │ ├── docscrape.py │ │ ├── husl.py │ │ ├── kde.py │ │ └── version.py │ ├── matrix.py │ ├── miscplot.py │ ├── objects.py │ ├── palettes.py │ ├── rcmod.py │ ├── regression.py │ ├── relational.py │ ├── utils.py │ └── widgets.py ├── setup.cfg └── tests/ ├── __init__.py ├── _core/ │ ├── __init__.py │ ├── test_data.py │ ├── test_groupby.py │ ├── test_moves.py │ ├── test_plot.py │ ├── test_properties.py │ ├── test_rules.py │ ├── test_scales.py │ └── test_subplots.py ├── _marks/ │ ├── __init__.py │ ├── test_area.py │ ├── test_bar.py │ ├── test_base.py │ ├── test_dot.py │ ├── test_line.py │ └── test_text.py ├── _stats/ │ ├── __init__.py │ ├── test_aggregation.py │ ├── test_counting.py │ ├── test_density.py │ ├── test_order.py │ └── test_regression.py ├── conftest.py ├── test_algorithms.py ├── test_axisgrid.py ├── test_base.py ├── test_categorical.py ├── test_distributions.py ├── test_docstrings.py ├── test_matrix.py ├── test_miscplot.py ├── test_objects.py ├── test_palettes.py ├── test_rcmod.py ├── test_regression.py ├── test_relational.py ├── test_statistics.py └── test_utils.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CONTRIBUTING.md ================================================ Contributing to seaborn ======================= General support --------------- General support questions ("how do I do X?") are most at home on [StackOverflow](https://stackoverflow.com/), which has a larger audience of people who will see your post and 
may be able to offer assistance. Your chance of getting a quick answer will be higher if you include runnable code, a precise statement of what you are hoping to achieve, and a clear explanation of the problems that you have encountered. Reporting bugs -------------- If you think you've encountered a bug in seaborn, please report it on the [Github issue tracker](https://github.com/mwaskom/seaborn/issues/new). To be useful, bug reports *must* include the following information: - A reproducible code example that demonstrates the problem - The output that you are seeing (an image of a plot, or the error message) - A clear explanation of why you think something is wrong - The specific versions of seaborn and matplotlib that you are working with Bug reports are easiest to address if they can be demonstrated using one of the example datasets from the seaborn docs (i.e. with `seaborn.load_dataset`). Otherwise, it is preferable that your example generate synthetic data to reproduce the problem. If you can only demonstrate the issue with your actual dataset, you will need to share it, ideally as a csv (do not share data as a pickle file). If you've encountered an error, searching the specific text of the message before opening a new issue can often help you solve the problem quickly and avoid making a duplicate report. Because matplotlib handles the actual rendering, errors or incorrect outputs may be due to a problem in matplotlib rather than one in seaborn. It can save time if you try to reproduce the issue in an example that uses only matplotlib, so that you can report it in the right place. But it is alright to skip this step if it's not obvious how to do it. New features ------------ If you think there is a new feature that should be added to seaborn, you can open an issue to discuss it. 
But please be aware that current development efforts are mostly focused on standardizing the API and internals, and there may be relatively low enthusiasm for novel features that do not fit well into short- and medium-term development plans. ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: # Maintain dependencies for GitHub Actions - package-ecosystem: "github-actions" directory: "/" schedule: # Check for updates to GitHub Actions every week interval: "weekly" ================================================ FILE: .github/workflows/ci.yaml ================================================ name: CI on: push: branches: [master, v0.*] pull_request: branches: master schedule: - cron: '0 6 * * 1,4' # Each Monday and Thursday at 06:00 UTC workflow_dispatch: permissions: contents: read env: NB_KERNEL: python MPLBACKEND: Agg SEABORN_DATA: ${{ github.workspace }}/seaborn-data PYDEVD_DISABLE_FILE_VALIDATION: 1 jobs: build-docs: runs-on: ubuntu-latest steps: - uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Setup Python 3.11 uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 with: python-version: "3.11" - name: Install seaborn run: | pip install --upgrade pip pip install .[stats,docs] - name: Install pandoc run: | wget https://github.com/jgm/pandoc/releases/download/3.1.11/pandoc-3.1.11-1-amd64.deb sudo dpkg -i pandoc-3.1.11-1-amd64.deb - name: Cache datasets run: | git clone https://github.com/mwaskom/seaborn-data.git ls $SEABORN_DATA - name: Build docs env: SPHINXOPTS: -j `nproc` run: | cd doc make -j `nproc` notebooks make html run-tests: runs-on: ubuntu-latest strategy: matrix: python: ["3.10", "3.11", "3.12", "3.13", "3.14"] install: [full] deps: [latest] include: - python: "3.10" install: full deps: pinned - python: "3.13" install: light deps: latest steps: - uses: 
actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Setup Python ${{ matrix.python }} uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 with: python-version: ${{ matrix.python }} allow-prereleases: true - name: Install seaborn run: | pip install --upgrade pip wheel if [[ ${{matrix.install}} == 'full' ]]; then EXTRAS=',stats'; fi if [[ ${{matrix.deps }} == 'pinned' ]]; then DEPS='-r ci/deps_pinned.txt'; fi pip install .[dev$EXTRAS] $DEPS - name: Run tests run: make test - name: Upload coverage uses: codecov/codecov-action@eaaf4bedf32dbdc6b720b63067d99c4d77d6047d # v3.1.4 if: ${{ success() }} lint: runs-on: ubuntu-latest strategy: fail-fast: false steps: - name: Checkout uses: actions/checkout@b4ffde65f46336ab88eb53be808477a3936bae11 # v4.1.1 - name: Setup Python uses: actions/setup-python@0a5c61591373683505ea898e09a3ea4f39ef2b9c # v5.0.0 - name: Install tools run: pip install mypy~=1.10.0 flake8 - name: Flake8 run: make lint - name: Type checking run: make typecheck ================================================ FILE: .gitignore ================================================ *.pyc *.sw* build/ .ipynb_checkpoints/ dist/ seaborn.egg-info/ .cache/ .coverage cover/ htmlcov/ .idea/ .vscode/ .pytest_cache/ .DS_Store notes/ notebooks/ ================================================ FILE: .pre-commit-config.yaml ================================================ repos: - repo: https://github.com/pre-commit/pre-commit-hooks rev: v4.3.0 hooks: - id: check-yaml - id: end-of-file-fixer - id: trailing-whitespace exclude: \.svg$ - repo: https://github.com/pycqa/flake8 rev: 5.0.4 hooks: - id: flake8 exclude: seaborn/(cm\.py|external/) types: [file, python] - repo: https://github.com/pre-commit/mirrors-mypy rev: v0.971 hooks: - id: mypy args: [--follow-imports=skip] files: seaborn/_(core|marks|stats)/ ================================================ FILE: CITATION.cff ================================================ cff-version: 
1.2.0 message: "If seaborn is integral to a scientific publication, please cite the following paper:" preferred-citation: type: article authors: - family-names: "Waskom" given-names: "Michael Lawrence" orcid: "https://orcid.org/0000-0002-9817-6869" doi: "10.21105/joss.03021" journal: "Journal of Open Source Software" month: April title: "seaborn: statistical data visualization" issue: 60 volume: 6 year: 2021 url: "https://joss.theoj.org/papers/10.21105/joss.03021" ================================================ FILE: LICENSE.md ================================================ Copyright (c) 2012-2023, Michael L. Waskom All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: * Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. * Neither the name of the project nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: Makefile ================================================ export SHELL := /bin/bash test: pytest -n auto --cov=seaborn --cov=tests --cov-config=setup.cfg tests lint: flake8 seaborn/ tests/ typecheck: mypy --follow-imports=skip seaborn/_core seaborn/_marks seaborn/_stats ================================================ FILE: README.md ================================================
-------------------------------------- seaborn: statistical data visualization ======================================= [![PyPI Version](https://img.shields.io/pypi/v/seaborn.svg)](https://pypi.org/project/seaborn/) [![License](https://img.shields.io/pypi/l/seaborn.svg)](https://github.com/mwaskom/seaborn/blob/master/LICENSE.md) [![DOI](https://joss.theoj.org/papers/10.21105/joss.03021/status.svg)](https://doi.org/10.21105/joss.03021) [![Tests](https://github.com/mwaskom/seaborn/workflows/CI/badge.svg)](https://github.com/mwaskom/seaborn/actions) [![Code Coverage](https://codecov.io/gh/mwaskom/seaborn/branch/master/graph/badge.svg)](https://codecov.io/gh/mwaskom/seaborn) Seaborn is a Python visualization library based on matplotlib. It provides a high-level interface for drawing attractive statistical graphics. Documentation ------------- Online documentation is available at [seaborn.pydata.org](https://seaborn.pydata.org). The docs include a [tutorial](https://seaborn.pydata.org/tutorial.html), [example gallery](https://seaborn.pydata.org/examples/index.html), [API reference](https://seaborn.pydata.org/api.html), [FAQ](https://seaborn.pydata.org/faq), and other useful information. To build the documentation locally, please refer to [`doc/README.md`](doc/README.md). Dependencies ------------ Seaborn supports Python 3.8+. Installation requires [numpy](https://numpy.org/), [pandas](https://pandas.pydata.org/), and [matplotlib](https://matplotlib.org/). Some advanced statistical functionality requires [scipy](https://www.scipy.org/) and/or [statsmodels](https://www.statsmodels.org/). 
Installation ------------ The latest stable release (and required dependencies) can be installed from PyPI: pip install seaborn It is also possible to include optional statistical dependencies: pip install seaborn[stats] Seaborn can also be installed with conda: conda install seaborn Note that the main anaconda repository lags PyPI in adding new releases, but conda-forge (`-c conda-forge`) typically updates quickly. Citing ------ A paper describing seaborn has been published in the [Journal of Open Source Software](https://joss.theoj.org/papers/10.21105/joss.03021). The paper provides an introduction to the key features of the library, and it can be used as a citation if seaborn proves integral to a scientific publication. Testing ------- Testing seaborn requires installing additional dependencies; they can be installed with the `dev` extra (e.g., `pip install .[dev]`). To test the code, run `make test` in the source directory. This will exercise the unit tests (using [pytest](https://docs.pytest.org/)) and generate a coverage report. Code style is enforced with `flake8` using the settings in the [`setup.cfg`](./setup.cfg) file. Run `make lint` to check. Alternately, you can use `pre-commit` to automatically run lint checks on any files you are committing: just run `pre-commit install` to set it up, and then commit as usual going forward. Development ----------- Seaborn development takes place on Github: https://github.com/mwaskom/seaborn Please submit bugs that you encounter to the [issue tracker](https://github.com/mwaskom/seaborn/issues) with a reproducible example demonstrating the problem. Questions about usage are more at home on StackOverflow, where there is a [seaborn tag](https://stackoverflow.com/questions/tagged/seaborn). ================================================ FILE: SECURITY.md ================================================ # Security Policy If you have discovered a security vulnerability in this project, please report it privately. 
**Do not disclose it as a public issue.** This gives me time to work with you to fix the issue before public exposure, reducing the chance that the exploit will be used before a patch is released. You may submit the report by filling out [this form](https://github.com/mwaskom/seaborn/security/advisories/new). Please provide the following information in your report: - A description of the vulnerability and its impact - How to reproduce the issue This project is maintained by a single maintainer on a reasonable-effort basis. As such, I ask that you give me 90 days to work on a fix before public exposure. ================================================ FILE: ci/cache_datasets.py ================================================ """ Cache test datasets before running tests / building docs. Avoids race conditions that would arise from parallelization. """ import pathlib import re from seaborn import load_dataset path = pathlib.Path(".") py_files = path.rglob("*.py") ipynb_files = path.rglob("*.ipynb") datasets = [] for fname in py_files: with open(fname) as fid: datasets += re.findall(r"load_dataset\(['\"](\w+)['\"]", fid.read()) for p in ipynb_files: with p.open() as fid: datasets += re.findall(r"load_dataset\(\\['\"](\w+)\\['\"]", fid.read()) for name in sorted(set(datasets)): print(f"Caching {name}") load_dataset(name) ================================================ FILE: ci/check_gallery.py ================================================ """Execute the scripts that comprise the example gallery in the online docs.""" from glob import glob import matplotlib.pyplot as plt if __name__ == "__main__": fnames = sorted(glob("examples/*.py")) for fname in fnames: print(f"- {fname}") with open(fname) as fid: exec(fid.read()) plt.close("all") ================================================ FILE: ci/deps_pinned.txt ================================================ numpy~=1.26.0 pandas~=2.0.0 matplotlib~=3.8.0 scipy~=1.11.0 statsmodels~=0.14.0 pillow~=10.3.0 
================================================ FILE: ci/getmsfonts.sh ================================================ echo ttf-mscorefonts-installer msttcorefonts/accepted-mscorefonts-eula select true | debconf-set-selections apt-get install msttcorefonts -qq ================================================ FILE: doc/.gitignore ================================================ *_files/ _build/ generated/ examples/ example_thumbs/*.png docstrings/ tutorial/ tutorial/_images tutorial.rst ================================================ FILE: doc/Makefile ================================================ # Makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build PAPER = BUILDDIR = _build # Internal variables. PAPEROPT_a4 = -D latex_paper_size=a4 PAPEROPT_letter = -D latex_paper_size=letter ALLSPHINXOPTS = -d $(BUILDDIR)/doctrees $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . # the i18n builder cannot share the environment and doctrees with the others I18NSPHINXOPTS = $(PAPEROPT_$(PAPER)) $(SPHINXOPTS) . 
.PHONY: help clean html dirhtml singlehtml pickle json htmlhelp qthelp devhelp epub latex latexpdf text man changes linkcheck doctest gettext help: @echo "Please use \`make ' where is one of" @echo " clean to remove generated output" @echo " html to make standalone HTML files" @echo " notebooks to make the Jupyter notebook-based tutorials" @echo " dirhtml to make HTML files named index.html in directories" @echo " singlehtml to make a single large HTML file" @echo " pickle to make pickle files" @echo " json to make JSON files" @echo " htmlhelp to make HTML files and a HTML help project" @echo " qthelp to make HTML files and a qthelp project" @echo " devhelp to make HTML files and a Devhelp project" @echo " epub to make an epub" @echo " latex to make LaTeX files, you can set PAPER=a4 or PAPER=letter" @echo " latexpdf to make LaTeX files and run them through pdflatex" @echo " text to make text files" @echo " man to make manual pages" @echo " texinfo to make Texinfo files" @echo " info to make Texinfo files and run them through makeinfo" @echo " gettext to make PO message catalogs" @echo " changes to make an overview of all changed/added/deprecated items" @echo " linkcheck to check all external links for integrity" @echo " doctest to run all doctests embedded in the documentation (if enabled)" clean: -rm -rf $(BUILDDIR)/* -rm -rf examples/* -rm -rf example_thumbs/* -rm -rf generated/* -rm -rf tutorial.rst -$(MAKE) -C _docstrings clean -$(MAKE) -C _tutorial clean .PHONY: tutorials tutorials: @mkdir -p tutorial @$(MAKE) -C _tutorial .PHONY: docstrings docstrings: @mkdir -p docstrings @$(MAKE) -C _docstrings notebooks: tutorials docstrings html: $(SPHINXBUILD) -b html $(ALLSPHINXOPTS) $(BUILDDIR)/html @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/html." dirhtml: $(SPHINXBUILD) -b dirhtml $(ALLSPHINXOPTS) $(BUILDDIR)/dirhtml @echo @echo "Build finished. The HTML pages are in $(BUILDDIR)/dirhtml." 
singlehtml: $(SPHINXBUILD) -b singlehtml $(ALLSPHINXOPTS) $(BUILDDIR)/singlehtml @echo @echo "Build finished. The HTML page is in $(BUILDDIR)/singlehtml." pickle: $(SPHINXBUILD) -b pickle $(ALLSPHINXOPTS) $(BUILDDIR)/pickle @echo @echo "Build finished; now you can process the pickle files." json: $(SPHINXBUILD) -b json $(ALLSPHINXOPTS) $(BUILDDIR)/json @echo @echo "Build finished; now you can process the JSON files." htmlhelp: $(SPHINXBUILD) -b htmlhelp $(ALLSPHINXOPTS) $(BUILDDIR)/htmlhelp @echo @echo "Build finished; now you can run HTML Help Workshop with the" \ ".hhp project file in $(BUILDDIR)/htmlhelp." qthelp: $(SPHINXBUILD) -b qthelp $(ALLSPHINXOPTS) $(BUILDDIR)/qthelp @echo @echo "Build finished; now you can run "qcollectiongenerator" with the" \ ".qhcp project file in $(BUILDDIR)/qthelp, like this:" @echo "# qcollectiongenerator $(BUILDDIR)/qthelp/lyman.qhcp" @echo "To view the help file:" @echo "# assistant -collectionFile $(BUILDDIR)/qthelp/lyman.qhc" devhelp: $(SPHINXBUILD) -b devhelp $(ALLSPHINXOPTS) $(BUILDDIR)/devhelp @echo @echo "Build finished." @echo "To view the help file:" @echo "# mkdir -p $$HOME/.local/share/devhelp/lyman" @echo "# ln -s $(BUILDDIR)/devhelp $$HOME/.local/share/devhelp/lyman" @echo "# devhelp" epub: $(SPHINXBUILD) -b epub $(ALLSPHINXOPTS) $(BUILDDIR)/epub @echo @echo "Build finished. The epub file is in $(BUILDDIR)/epub." latex: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo @echo "Build finished; the LaTeX files are in $(BUILDDIR)/latex." @echo "Run \`make' in that directory to run these through (pdf)latex" \ "(use \`make latexpdf' here to do that automatically)." latexpdf: $(SPHINXBUILD) -b latex $(ALLSPHINXOPTS) $(BUILDDIR)/latex @echo "Running LaTeX files through pdflatex..." $(MAKE) -C $(BUILDDIR)/latex all-pdf @echo "pdflatex finished; the PDF files are in $(BUILDDIR)/latex." text: $(SPHINXBUILD) -b text $(ALLSPHINXOPTS) $(BUILDDIR)/text @echo @echo "Build finished. 
The text files are in $(BUILDDIR)/text." man: $(SPHINXBUILD) -b man $(ALLSPHINXOPTS) $(BUILDDIR)/man @echo @echo "Build finished. The manual pages are in $(BUILDDIR)/man." texinfo: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo @echo "Build finished. The Texinfo files are in $(BUILDDIR)/texinfo." @echo "Run \`make' in that directory to run these through makeinfo" \ "(use \`make info' here to do that automatically)." info: $(SPHINXBUILD) -b texinfo $(ALLSPHINXOPTS) $(BUILDDIR)/texinfo @echo "Running Texinfo files through makeinfo..." make -C $(BUILDDIR)/texinfo info @echo "makeinfo finished; the Info files are in $(BUILDDIR)/texinfo." gettext: $(SPHINXBUILD) -b gettext $(I18NSPHINXOPTS) $(BUILDDIR)/locale @echo @echo "Build finished. The message catalogs are in $(BUILDDIR)/locale." changes: $(SPHINXBUILD) -b changes $(ALLSPHINXOPTS) $(BUILDDIR)/changes @echo @echo "The overview file is in $(BUILDDIR)/changes." linkcheck: $(SPHINXBUILD) -b linkcheck $(ALLSPHINXOPTS) $(BUILDDIR)/linkcheck @echo @echo "Link check complete; look for any errors in the above output " \ "or in $(BUILDDIR)/linkcheck/output.txt." doctest: $(SPHINXBUILD) -b doctest $(ALLSPHINXOPTS) $(BUILDDIR)/doctest @echo "Testing of doctests in the sources finished, look at the " \ "results in $(BUILDDIR)/doctest/output.txt." ================================================ FILE: doc/README.md ================================================ Building the seaborn docs ========================= Building the docs requires additional dependencies; they can be installed with `pip install seaborn[stats,docs]`. The build process involves conversion of Jupyter notebooks to `rst` files. To facilitate this, you may need to set `NB_KERNEL` environment variable to the name of a kernel on your machine (e.g. `export NB_KERNEL="python3"`). To get a list of available Python kernels, run `jupyter kernelspec list`. 
After you're set up, run `make notebooks html` from the `doc` directory to convert all notebooks, generate all gallery examples, and build the documentation itself. The site will live in `_build/html`. Run `make clean` to delete the built site and all intermediate files. Run `make -C _docstrings clean` or `make -C _tutorial clean` to remove intermediate files for the API or tutorial components. If your goal is to obtain an offline copy of the docs for a released version, it may be easier to clone the [website repository](https://github.com/seaborn/seaborn.github.io) or to download a zipfile corresponding to a [specific version](https://github.com/seaborn/seaborn.github.io/tags). ================================================ FILE: doc/_docstrings/FacetGrid.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"ticks\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Calling the constructor requires a long-form data object. 
This initializes the grid, but doesn't plot anything on it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.FacetGrid(tips)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assign column and/or row variables to add more subplots to the figure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.FacetGrid(tips, col=\"time\", row=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To draw a plot on every facet, pass a function and the name of one or more columns in the dataframe to :meth:`FacetGrid.map`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", "g.map(sns.scatterplot, \"total_bill\", \"tip\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The variable specification in :meth:`FacetGrid.map` requires a positional argument mapping, but if the function has a ``data`` parameter and accepts named variable assignments, you can also use :meth:`FacetGrid.map_dataframe`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", "g.map_dataframe(sns.histplot, x=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Notice how the bins have different widths in each facet. A separate plot is drawn on each facet, so if the plotting function derives any parameters from the data, they may not be shared across facets. You can pass additional keyword arguments to synchronize them. 
But when possible, using a figure-level function like :func:`displot` will take care of this bookkeeping for you:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\", row=\"sex\")\n", "g.map_dataframe(sns.histplot, x=\"total_bill\", binwidth=2, binrange=(0, 60))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The :class:`FacetGrid` constructor accepts a ``hue`` parameter. Setting this will condition the data on another variable and make multiple plots in different colors. Where possible, label information is tracked so that a single legend can be drawn:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\", hue=\"sex\")\n", "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When ``hue`` is set on the :class:`FacetGrid`, however, a separate plot is drawn for each level of the variable. If the plotting function understands ``hue``, it is better to let it handle that logic. But it is important to ensure that each facet will use the same hue mapping. In the sample ``tips`` data, the ``sex`` column has a categorical datatype, which ensures this. 
Otherwise, you may want to use the `hue_order` or similar parameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\")\n", "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\", hue=\"sex\")\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The size and shape of the plot is specified at the level of each subplot using the ``height`` and ``aspect`` parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"day\", height=3.5, aspect=.65)\n", "g.map(sns.histplot, \"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If the variable assigned to ``col`` has many levels, it is possible to \"wrap\" it so that it spans multiple rows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"size\", height=2.5, col_wrap=3)\n", "g.map(sns.histplot, \"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To add horizontal or vertical reference lines on every facet, use :meth:`FacetGrid.refline`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\", margin_titles=True)\n", "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", "g.refline(y=tips[\"tip\"].median())" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can pass custom functions to plot with, or to annotate each facet. 
Your custom function must use the matplotlib state-machine interface to plot on the \"current\" axes, and it should catch additional keyword arguments:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "def annotate(data, **kws):\n", " n = len(data)\n", " ax = plt.gca()\n", " ax.text(.1, .6, f\"N = {n}\", transform=ax.transAxes)\n", "\n", "g = sns.FacetGrid(tips, col=\"time\")\n", "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", "g.map_dataframe(annotate)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The :class:`FacetGrid` object has some other useful parameters and methods for tweaking the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"sex\", row=\"time\", margin_titles=True)\n", "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", "g.set_axis_labels(\"Total bill ($)\", \"Tip ($)\")\n", "g.set_titles(col_template=\"{col_name} patrons\", row_template=\"{row_name}\")\n", "g.set(xlim=(0, 60), ylim=(0, 12), xticks=[10, 30, 50], yticks=[2, 6, 10])\n", "g.tight_layout()\n", "g.savefig(\"facet_plot.png\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import os\n", "if os.path.exists(\"facet_plot.png\"):\n", " os.remove(\"facet_plot.png\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You also have access to the underlying matplotlib objects for additional tweaking:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"sex\", row=\"time\", margin_titles=True, despine=False)\n", "g.map_dataframe(sns.scatterplot, x=\"total_bill\", y=\"tip\")\n", "g.figure.subplots_adjust(wspace=0, hspace=0)\n", "for (row_val, col_val), ax in g.axes_dict.items():\n", " if row_val == \"Lunch\" and col_val == \"Female\":\n", " 
ax.set_facecolor(\".95\")\n", " else:\n", " ax.set_facecolor((0, 0, 0, 0))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/JointGrid.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Calling the constructor initializes the figure, but it does not plot anything:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The simplest plotting method, :meth:`JointGrid.plot`, accepts a pair of functions (one for the joint axes and one for both marginal axes):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot(sns.scatterplot, sns.histplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The :meth:`JointGrid.plot` function also accepts additional keyword arguments, but it passes them to both functions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot(sns.scatterplot, 
sns.histplot, alpha=.7, edgecolor=\".2\", linewidth=.5)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If you need to pass different keyword arguments to each function, you'll have to invoke :meth:`JointGrid.plot_joint` and :meth:`JointGrid.plot_marginals`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot_joint(sns.scatterplot, s=100, alpha=.5)\n", "g.plot_marginals(sns.histplot, kde=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can also set up the grid without assigning any data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can then plot by accessing the ``ax_joint``, ``ax_marg_x``, and ``ax_marg_y`` attributes, which are :class:`matplotlib.axes.Axes` objects:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid()\n", "x, y = penguins[\"bill_length_mm\"], penguins[\"bill_depth_mm\"]\n", "sns.scatterplot(x=x, y=y, ec=\"b\", fc=\"none\", s=100, linewidth=1.5, ax=g.ax_joint)\n", "sns.histplot(x=x, fill=False, linewidth=2, ax=g.ax_marg_x)\n", "sns.kdeplot(y=y, linewidth=2, ax=g.ax_marg_y)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The plotting methods can use any seaborn functions that accept ``x`` and ``y`` variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot(sns.regplot, sns.boxplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If the functions accept a ``hue`` variable, you can use it by assigning ``hue`` when you call the constructor:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = 
sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")\n", "g.plot(sns.scatterplot, sns.histplot)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Horizontal and/or vertical reference lines can be added to the joint and/or marginal axes using :meth:`JointGrid.refline`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot(sns.scatterplot, sns.histplot)\n", "g.refline(x=45, y=16)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The figure will always be square (unless you resize it at the matplotlib layer), but its overall size and layout are configurable. The size is controlled by the ``height`` parameter. The relative ratio between the joint and marginal axes is controlled by ``ratio``, and the amount of space between the plots is controlled by ``space``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.JointGrid(height=4, ratio=2, space=.05)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the ticks on the density axis of the marginal plots are turned off, but this is configurable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.JointGrid(marginal_ticks=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Limits on the two data axes (which are shared across plots) can also be defined when setting up the figure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.JointGrid(xlim=(-2, 5), ylim=(0, 10))" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": 
"ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/Makefile ================================================ rst_files := $(patsubst %.ipynb,../docstrings/%.rst,$(wildcard *.ipynb)) export MPLBACKEND := module://matplotlib_inline.backend_inline docstrings: ${rst_files} ../docstrings/%.rst: %.ipynb ../tools/nb_to_doc.py $*.ipynb ../docstrings @cp -r ../docstrings/$*_files ../generated/ @if [ -f ../generated/seaborn.$*.rst ]; then \ touch ../generated/seaborn.$*.rst; \ fi clean: rm -rf ../docstrings ================================================ FILE: doc/_docstrings/PairGrid.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns; sns.set_theme()\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Calling the constructor sets up a blank grid of subplots with each row and one column corresponding to a numeric variable in the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "g = sns.PairGrid(penguins)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Passing a bivariate function to :meth:`PairGrid.map` will draw a bivariate plot on every axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins)\n", "g.map(sns.scatterplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Passing separate functions to :meth:`PairGrid.map_diag` and :meth:`PairGrid.map_offdiag` will show each variable's marginal distribution on the diagonal:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins)\n", "g.map_diag(sns.histplot)\n", 
"g.map_offdiag(sns.scatterplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to use different functions on the upper and lower triangles of the plot (which are otherwise redundant):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, diag_sharey=False)\n", "g.map_upper(sns.scatterplot)\n", "g.map_lower(sns.kdeplot)\n", "g.map_diag(sns.kdeplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or to avoid the redundancy altogether:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, diag_sharey=False, corner=True)\n", "g.map_lower(sns.scatterplot)\n", "g.map_diag(sns.kdeplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The :class:`PairGrid` constructor accepts a ``hue`` variable. This variable is passed directly to functions that understand it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, hue=\"species\")\n", "g.map_diag(sns.histplot)\n", "g.map_offdiag(sns.scatterplot)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "But you can also pass matplotlib functions, in which case a groupby is performed internally and a separate plot is drawn for each level:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, hue=\"species\")\n", "g.map_diag(plt.hist)\n", "g.map_offdiag(plt.scatter)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Additional semantic variables can be assigned by passing data vectors directly while mapping the function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, hue=\"species\")\n", "g.map_diag(sns.histplot)\n", "g.map_offdiag(sns.scatterplot, size=penguins[\"sex\"])\n", 
"g.add_legend(title=\"\", adjust_subtitles=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When using seaborn functions that can implement a numeric hue mapping, you will want to disable mapping of the variable on the diagonal axes. Note that the ``hue`` variable is excluded from the list of variables shown by default:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, hue=\"body_mass_g\")\n", "g.map_diag(sns.histplot, hue=None, color=\".3\")\n", "g.map_offdiag(sns.scatterplot)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``vars`` parameter can be used to control exactly which variables are used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "variables = [\"body_mass_g\", \"bill_length_mm\", \"flipper_length_mm\"]\n", "g = sns.PairGrid(penguins, hue=\"body_mass_g\", vars=variables)\n", "g.map_diag(sns.histplot, hue=None, color=\".3\")\n", "g.map_offdiag(sns.scatterplot)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The plot need not be square: separate variables can be used to define the rows and columns:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x_vars = [\"body_mass_g\", \"bill_length_mm\", \"bill_depth_mm\", \"flipper_length_mm\"]\n", "y_vars = [\"body_mass_g\"]\n", "g = sns.PairGrid(penguins, hue=\"species\", x_vars=x_vars, y_vars=y_vars)\n", "g.map_diag(sns.histplot, color=\".3\")\n", "g.map_offdiag(sns.scatterplot)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It can be useful to explore different approaches to resolving multiple distributions on the diagonal axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, hue=\"species\")\n", "g.map_diag(sns.histplot, multiple=\"stack\", 
element=\"step\")\n", "g.map_offdiag(sns.scatterplot)\n", "g.add_legend()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/axes_style.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "dated-mother", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns" ] }, { "cell_type": "markdown", "id": "prospective-sellers", "metadata": {}, "source": [ "Calling with no arguments will return the current defaults for the style parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "recognized-rehabilitation", "metadata": { "tags": [ "show-output" ] }, "outputs": [], "source": [ "sns.axes_style()" ] }, { "cell_type": "markdown", "id": "furnished-irrigation", "metadata": {}, "source": [ "Calling with the name of a predefined style will show those parameter values:" ] }, { "cell_type": "code", "execution_count": null, "id": "coordinate-reward", "metadata": { "tags": [ "show-output" ] }, "outputs": [], "source": [ "sns.axes_style(\"darkgrid\")" ] }, { "cell_type": "markdown", "id": "mediterranean-picking", "metadata": {}, "source": [ "Use the function as a context manager to temporarily change the style of your plots:" ] }, { "cell_type": "code", "execution_count": null, "id": "missing-essence", "metadata": {}, "outputs": [], "source": [ "with sns.axes_style(\"whitegrid\"):\n", " sns.barplot(x=[1, 2, 3], y=[2, 5, 3])" ] } ], "metadata": { "kernelspec": { "display_name": "py310", 
"language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/barplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "6a6d582b-08c2-4fed-be56-afa1b986943a", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")\n", "penguins = sns.load_dataset(\"penguins\")\n", "flights = sns.load_dataset(\"flights\")" ] }, { "cell_type": "raw", "id": "b53b65b8-5670-4905-aa39-36db04f4b813", "metadata": {}, "source": [ "With long data, assign `x` and `y` to group by a categorical variable and plot aggregated values, with confidence intervals:" ] }, { "cell_type": "code", "execution_count": null, "id": "0f5c3ece-6295-4933-8a87-e80cd604c089", "metadata": {}, "outputs": [], "source": [ "sns.barplot(penguins, x=\"island\", y=\"body_mass_g\")" ] }, { "cell_type": "raw", "id": "ed061d6f-bd3b-4189-bbc7-aed998be05cb", "metadata": {}, "source": [ "Prior to v0.13.0, each bar would have a different color. 
To replicate this behavior, assign the grouping variable to `hue` as well:" ] }, { "cell_type": "code", "execution_count": null, "id": "3ded2e23-c610-450b-bcd2-1d2ba54db566", "metadata": {}, "outputs": [], "source": [ "sns.barplot(penguins, x=\"body_mass_g\", y=\"island\", hue=\"island\", legend=False)" ] }, { "cell_type": "raw", "id": "e00fa127-4dd4-4565-9897-51317adfea3c", "metadata": {}, "source": [ "When plotting a \"wide-form\" dataframe, each column will be aggregated and represented as a bar:" ] }, { "cell_type": "code", "execution_count": null, "id": "ae7e0f4e-471e-4dee-8913-5e7b67e0a381", "metadata": {}, "outputs": [], "source": [ "flights_wide = flights.pivot(index=\"year\", columns=\"month\", values=\"passengers\")\n", "sns.barplot(flights_wide)" ] }, { "cell_type": "raw", "id": "6020404c-15c6-4c00-9ffd-6c12ba624e52", "metadata": {}, "source": [ "Passing only a series (or dict) will plot each of its values, using the index (or keys) to label the bars:" ] }, { "cell_type": "code", "execution_count": null, "id": "77b2c3eb-c3e4-4d44-929a-27a456da4b88", "metadata": {}, "outputs": [], "source": [ "sns.barplot(flights_wide[\"Jun\"])" ] }, { "cell_type": "raw", "id": "b0c3b101-7649-4014-9ab2-10ff206d39d7", "metadata": {}, "source": [ "With long-form data, you can add a second layer of grouping with `hue`:" ] }, { "cell_type": "code", "execution_count": null, "id": "ac1a28d1-b3bd-4158-86d0-3defc12f8566", "metadata": {}, "outputs": [], "source": [ "sns.barplot(penguins, x=\"island\", y=\"body_mass_g\", hue=\"sex\")" ] }, { "cell_type": "raw", "id": "069ce509-ee0d-42c8-b053-1b4b6d764449", "metadata": {}, "source": [ "Use the error bars to show the standard deviation rather than a confidence interval:" ] }, { "cell_type": "code", "execution_count": null, "id": "10445b78-a74a-4f14-a28b-a9164e592ae4", "metadata": {}, "outputs": [], "source": [ "sns.barplot(penguins, x=\"island\", y=\"body_mass_g\", errorbar=\"sd\")" ] }, { "cell_type": "raw", "id": 
"6dc3d564-4d26-4753-a2a0-6194b10452bc", "metadata": {}, "source": [ "Use a different aggregation function and disable the error bars:" ] }, { "cell_type": "code", "execution_count": null, "id": "448ba05e-c533-459d-84b6-0fca80e6e3ce", "metadata": {}, "outputs": [], "source": [ "sns.barplot(flights, x=\"year\", y=\"passengers\", estimator=\"sum\", errorbar=None)" ] }, { "cell_type": "raw", "id": "7746220d-b6b4-4ee5-886c-5867db35d4e3", "metadata": {}, "source": [ "Add text labels with each bar's value:" ] }, { "cell_type": "code", "execution_count": null, "id": "e343485c-636e-4b96-b20d-59a7f7155be8", "metadata": {}, "outputs": [], "source": [ "ax = sns.barplot(flights, x=\"year\", y=\"passengers\", estimator=\"sum\", errorbar=None)\n", "ax.bar_label(ax.containers[0], fontsize=10);" ] }, { "cell_type": "raw", "id": "457702c2-9fa6-4021-a19b-f44b39aa0a19", "metadata": {}, "source": [ "Preserve the original scaling of the grouping variable and add annotations in numeric coordinates:" ] }, { "cell_type": "code", "execution_count": null, "id": "08b60118-5830-4fd7-8a66-431c065d57cb", "metadata": {}, "outputs": [], "source": [ "ax = sns.barplot(\n", " flights, x=\"year\", y=\"passengers\",\n", " native_scale=True,\n", " estimator=\"sum\", errorbar=None,\n", ")\n", "ax.plot(1955, 3600, \"*\", markersize=10, color=\"r\")" ] }, { "cell_type": "raw", "id": "206be839-f33b-4ffe-8101-bd98bc5942b8", "metadata": {}, "source": [ "Use `orient` to resolve ambiguity about which variable should group when both are numeric:" ] }, { "cell_type": "code", "execution_count": null, "id": "3aff3c69-3c24-40ad-af12-a507e33f5d3f", "metadata": {}, "outputs": [], "source": [ "sns.barplot(flights, x=\"passengers\", y=\"year\", orient=\"y\")" ] }, { "cell_type": "raw", "id": "90277a3b-1f86-4884-97ad-e5d65df408ef", "metadata": {}, "source": [ "Customize the appearance of the plot using :class:`matplotlib.patches.Rectangle` and :class:`matplotlib.lines.Line2D` keyword arguments:" ] }, { "cell_type": 
"code", "execution_count": null, "id": "d6f9ac1c-a77d-4ee3-bc5e-fec2071b33df", "metadata": {}, "outputs": [], "source": [ "sns.barplot(\n", " penguins, x=\"body_mass_g\", y=\"island\",\n", " errorbar=(\"pi\", 50), capsize=.4,\n", " err_kws={\"color\": \".5\", \"linewidth\": 2.5},\n", " linewidth=2.5, edgecolor=\".5\", facecolor=(0, 0, 0, 0),\n", ")" ] }, { "cell_type": "raw", "id": "08ef562f-13a3-4da5-a9cf-46deaa543890", "metadata": {}, "source": [ "Use :func:`catplot` to draw faceted bars, which is recommended over working directly with :class:`FacetGrid`:" ] }, { "cell_type": "code", "execution_count": null, "id": "4d23777f-8a69-4c68-ab35-3e6740c61bcf", "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " penguins, kind=\"bar\",\n", " x=\"sex\", y=\"body_mass_g\", col=\"species\",\n", " height=4, aspect=.5,\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "0b6a62b9-eef7-4c85-a1c2-85a58231e6c6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/blend_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "8f97280e-cec8-42b2-a968-4fd4364594f8", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "972edede-df1a-4010-9674-00b864d020e2", "metadata": {}, "source": [ "Pass a list of two colors to interpolate between them:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"e6ae2547-1042-4ac0-84ea-6f37a0229871", "metadata": {}, "outputs": [], "source": [ "sns.blend_palette([\"b\", \"r\"])" ] }, { "cell_type": "raw", "id": "1d983eac-2dd5-4746-b27f-4dfa19b5e091", "metadata": {}, "source": [ "The color list can be arbitrarily long, and any color format can be used:" ] }, { "cell_type": "code", "execution_count": null, "id": "846b78fd-30ce-4507-93f4-4274122c1987", "metadata": {}, "outputs": [], "source": [ "sns.blend_palette([\"#45a872\", \".8\", \"xkcd:golden\"])" ] }, { "cell_type": "raw", "id": "318fef32-1f83-44d9-9ff9-21fa0231b7c6", "metadata": {}, "source": [ "Return a continuous colormap instead of a discrete palette:" ] }, { "cell_type": "code", "execution_count": null, "id": "f0a05bc3-c60b-47a1-b276-d2e28a4a8226", "metadata": {}, "outputs": [], "source": [ "sns.blend_palette([\"#bdc\", \"#7b9\", \"#47a\"], as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "0473a402-0ec2-4877-81d2-ed6c57aefc77", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/boxenplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "882d215b-88d8-4b5e-ae7a-0e3f6bb53bad", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")\n", "diamonds = sns.load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "9b8b892e-a96f-46e8-9c5e-8749783608d8", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { 
"slide_type": "" }, "tags": [] }, "source": [ "Draw a single horizontal plot, assigning the data directly to the coordinate variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "391e1162-b438-4486-9a08-60686ee8e96a", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(x=diamonds[\"price\"])" ] }, { "cell_type": "raw", "id": "b0c5a469-c709-4333-a8bc-b2cb34f366aa", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Group by a categorical variable, referencing columns in a datafame" ] }, { "cell_type": "code", "execution_count": null, "id": "e30fec18-f127-40a3-bfaf-f71324dd60ec", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(data=diamonds, x=\"price\", y=\"clarity\")" ] }, { "cell_type": "raw", "id": "70fe999a-bea5-4b0a-a1a3-474b6696d1be", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Group by another variable, representing it by the color of the boxes. By default, each boxen plot will be \"dodged\" so that they don't overlap; you can also add a small gap between them:" ] }, { "cell_type": "code", "execution_count": null, "id": "eed3239c-57b7-4d76-9fdc-be99257047fd", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "large_diamond = diamonds[\"carat\"].gt(1).rename(\"large_diamond\")\n", "sns.boxenplot(data=diamonds, x=\"price\", y=\"clarity\", hue=large_diamond, gap=.2)" ] }, { "cell_type": "raw", "id": "36030c1c-047b-4f7b-b366-91188b41680e", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "The default rule for choosing each box width represents the percentile covered by the box. 
Alternatively, you can reduce each box width by a linear factor:" ] }, { "cell_type": "code", "execution_count": null, "id": "d0c1aa43-5e8a-486c-bd6d-3c29d6d23138", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(data=diamonds, x=\"price\", y=\"clarity\", width_method=\"linear\")" ] }, { "cell_type": "raw", "id": "062a9fc2-9cbe-4e40-af8c-3fd35f785cd5", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "The `width` parameter itself, on the other hand, determines the width of the largest box:" ] }, { "cell_type": "code", "execution_count": null, "id": "4100a460-fe27-42b7-bbaf-4430a1c1359f", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(data=diamonds, x=\"price\", y=\"clarity\", width=.5)" ] }, { "cell_type": "raw", "id": "407874a8-1202-4bcc-9f65-59e1fed29e07", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "There are several different approaches for choosing the number of boxes to draw, including a rule based on the confidence level of the percentile estimate:" ] }, { "cell_type": "code", "execution_count": null, "id": "1aead6a3-6f12-47d3-b472-a39c61867963", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(data=diamonds, x=\"price\", y=\"clarity\", k_depth=\"trustworthy\", trust_alpha=0.01)" ] }, { "cell_type": "raw", "id": "71212196-d60e-4682-8dcb-0289956be152", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "The `linecolor` and `linewidth` parameters control the outlines of the boxes, while the `line_kws` parameter controls the line representing the median and the `flier_kws` parameter controls the appearance of the outliers:" ] }, { "cell_type": 
"code", "execution_count": null, "id": "dd103426-a99f-476b-ae29-a11d52958cdb", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(\n", " data=diamonds, x=\"price\", y=\"clarity\",\n", " linewidth=.5, linecolor=\".7\",\n", " line_kws=dict(linewidth=1.5, color=\"#cde\"),\n", " flier_kws=dict(facecolor=\".7\", linewidth=.5),\n", ")" ] }, { "cell_type": "raw", "id": "16f1c534-3316-4752-ae12-f65dee9275cb", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "It is also possible to draw unfilled boxes. With unfilled boxes, all elements will be drawn as line art and follow `hue`, when used:" ] }, { "cell_type": "code", "execution_count": null, "id": "ab6aef09-5bbe-4c01-b6ba-05446982d775", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "sns.boxenplot(data=diamonds, x=\"price\", y=\"clarity\", hue=\"clarity\", fill=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "e059b944-ea59-408d-87bb-4ce65074dab5", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/boxplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "7edcf92f-6c11-4dc4-b684-118b3235d067", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")\n", "titanic = 
sns.load_dataset(\"titanic\")" ] }, { "cell_type": "raw", "id": "4ca96805-333b-4186-9ad7-dcef4a9aacf5", "metadata": {}, "source": [ "Draw a single horizontal boxplot, assigning the data directly to the coordinate variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "80532f2c-0f34-456c-9d5c-673682385461", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(x=titanic[\"age\"])" ] }, { "cell_type": "raw", "id": "d9e33318-9595-4132-bfbd-8d88905fea79", "metadata": {}, "source": [ "Group by a categorical variable, referencing columns in a dataframe:" ] }, { "cell_type": "code", "execution_count": null, "id": "f1e0a6a4-151d-42d7-a098-ec9b91f20906", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(data=titanic, x=\"age\", y=\"class\")" ] }, { "cell_type": "raw", "id": "d1e0d9e7-2d9b-49e3-8bb3-d97f2de7e733", "metadata": {}, "source": [ "Draw a vertical boxplot with nested grouping by two variables:" ] }, { "cell_type": "code", "execution_count": null, "id": "b8f74dc4-2b59-423a-90a7-dbf900c89251", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(data=titanic, x=\"class\", y=\"age\", hue=\"alive\")" ] }, { "cell_type": "raw", "id": "59aaff3f-2bba-44d1-9901-2dd680bad3ad", "metadata": {}, "source": [ "Draw the boxes as line art and add a small gap between them:" ] }, { "cell_type": "code", "execution_count": null, "id": "6af681be-c49e-4794-8a92-90c58ef330f9", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(data=titanic, x=\"class\", y=\"age\", hue=\"alive\", fill=False, gap=.1)" ] }, { "cell_type": "raw", "id": "db4ef9cb-0f0d-458b-a06d-c537c2b4d733", "metadata": {}, "source": [ "Cover the full range of the data with the whiskers:" ] }, { "cell_type": "code", "execution_count": null, "id": "89aab45a-bc58-44e9-94ac-6a9aa0b20f5e", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(data=titanic, x=\"age\", y=\"deck\", whis=(0, 100))" ] }, { "cell_type": "raw", "id": "3844cc78-19a5-46e3-babd-77d6d7affcf0", "metadata": {}, "source": [ 
"Draw narrower boxes:" ] }, { "cell_type": "code", "execution_count": null, "id": "399825eb-698a-4464-8a04-505b6bf7edc7", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(data=titanic, x=\"age\", y=\"deck\", width=.5)" ] }, { "cell_type": "raw", "id": "eaf35104-022d-4a20-9b60-f8b24acc7471", "metadata": {}, "source": [ "Modify the color and width of all the line artists:" ] }, { "cell_type": "code", "execution_count": null, "id": "6e9dcaa3-b497-480e-b134-d31e01a7d4c5", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(data=titanic, x=\"age\", y=\"deck\", color=\".8\", linecolor=\"#137\", linewidth=.75)" ] }, { "cell_type": "markdown", "id": "8a188c80-d69f-4a07-9b0d-ca467d2be680", "metadata": {}, "source": [ "Group by a numeric variable and preserve its native scaling:" ] }, { "cell_type": "code", "execution_count": null, "id": "9d73c63f-58a8-4659-96fd-964493ba3a50", "metadata": {}, "outputs": [], "source": [ "ax = sns.boxplot(x=titanic[\"age\"].round(-1), y=titanic[\"fare\"], native_scale=True)\n", "ax.axvline(25, color=\".3\", dashes=(2, 2))" ] }, { "cell_type": "raw", "id": "28536179-8400-462d-bf3e-3d9f353fe03b", "metadata": {}, "source": [ "Customize the plot using parameters of the underlying matplotlib function:" ] }, { "cell_type": "code", "execution_count": null, "id": "66c81b6e-e7fb-46c5-aa7b-f001241569b0", "metadata": {}, "outputs": [], "source": [ "sns.boxplot(\n", " data=titanic, x=\"age\", y=\"class\",\n", " notch=True, showcaps=False,\n", " flierprops={\"marker\": \"x\"},\n", " boxprops={\"facecolor\": (.3, .5, .7, .5)},\n", " medianprops={\"color\": \"r\", \"linewidth\": 2},\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "5d2bb11b-0f4a-4efe-b18b-be34ebf24e49", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": 
"text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/catplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "a8aa6a6a-f6c0-4a6b-9460-2056e58a2e13", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")" ] }, { "cell_type": "raw", "id": "1aef2740-ae6e-4a1b-a588-3ad978e2614d", "metadata": {}, "source": [ "By default, the visual representation will be a jittered strip plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "75a49e26-4318-4963-897c-dc0081aebfb3", "metadata": {}, "outputs": [], "source": [ "df = sns.load_dataset(\"titanic\")\n", "sns.catplot(data=df, x=\"age\", y=\"class\")" ] }, { "cell_type": "markdown", "id": "db1b8f6d-5264-4200-b81a-b0ee64040a1f", "metadata": {}, "source": [ "Use `kind` to select a different representation:" ] }, { "cell_type": "code", "execution_count": null, "id": "75ecd034-8536-4fe4-8852-a3975dba64dc", "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=df, x=\"age\", y=\"class\", kind=\"box\")" ] }, { "cell_type": "markdown", "id": "8aee79a9-b8b3-4129-b6d7-e9e32ae1e634", "metadata": {}, "source": [ "One advantage is that the legend will be automatically placed outside the plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "3798aac6-1ff6-4e36-ad83-4742fcb04159", "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=df, x=\"age\", y=\"class\", hue=\"sex\", kind=\"boxen\")" ] }, { "cell_type": "markdown", "id": "8a3777e1-90b6-4f4d-9e14-247b6dfd64fe", "metadata": {}, "source": [ "Additional keyword arguments get passed through to the underlying seaborn function:" ] }, { "cell_type": "code", "execution_count": null, "id": "afcff2fe-db11-4602-af79-68e4a0380f88", "metadata": 
{}, "outputs": [], "source": [ "sns.catplot(\n", " data=df, x=\"age\", y=\"class\", hue=\"sex\",\n", " kind=\"violin\", bw_adjust=.5, cut=0, split=True,\n", ")" ] }, { "cell_type": "markdown", "id": "a75bf46f-a3d0-4a5d-abcd-b9e85def65b0", "metadata": {}, "source": [ "Assigning a variable to `col` or `row` will automatically create subplots. Control figure size with the `height` and `aspect` parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "835afcf2-ecc9-4edb-9ec8-24484c5b08fb", "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=df, x=\"class\", y=\"survived\", col=\"sex\",\n", " kind=\"bar\", height=4, aspect=.6,\n", ")" ] }, { "cell_type": "markdown", "id": "ecf323fe-1e86-47ff-aa50-e8c297cfa125", "metadata": {}, "source": [ "For single-subplot figures, it is easy to layer different representations:" ] }, { "cell_type": "code", "execution_count": null, "id": "dc5b0fc0-359c-4219-b04e-171d8c7c8051", "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=df, x=\"age\", y=\"class\", kind=\"violin\", color=\".9\", inner=None)\n", "sns.swarmplot(data=df, x=\"age\", y=\"class\", size=3)" ] }, { "cell_type": "raw", "id": "26e06ba4-0457-4597-b699-cb0fe8b2be32", "metadata": {}, "source": [ "Use methods on the returned :class:`FacetGrid` to tweak the presentation:" ] }, { "cell_type": "code", "execution_count": null, "id": "a43f1914-d868-4060-82df-b3d25553d595", "metadata": {}, "outputs": [], "source": [ "g = sns.catplot(\n", " data=df, x=\"who\", y=\"survived\", col=\"class\",\n", " kind=\"bar\", height=4, aspect=.6,\n", ")\n", "g.set_axis_labels(\"\", \"Survival Rate\")\n", "g.set_xticklabels([\"Men\", \"Women\", \"Children\"])\n", "g.set_titles(\"{col_name} {col_var}\")\n", "g.set(ylim=(0, 1))\n", "g.despine(left=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "a529c18c-45bc-4efb-8ae0-c14518349162", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", 
"language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/clustermap.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "ffc1e1d9-fa74-4121-aa87-e1a8665e4c2b", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()" ] }, { "cell_type": "raw", "id": "41b4f602-32af-44f8-bf1a-0f1695c9abbb", "metadata": {}, "source": [ "Plot a heatmap with row and column clustering:" ] }, { "cell_type": "code", "execution_count": null, "id": "c715bd8f-cf5d-4caa-9244-336b3d0248a8", "metadata": {}, "outputs": [], "source": [ "iris = sns.load_dataset(\"iris\")\n", "species = iris.pop(\"species\")\n", "sns.clustermap(iris)" ] }, { "cell_type": "raw", "id": "1cc3134c-579a-442a-97d8-a878651ce90a", "metadata": {}, "source": [ "Change the size and layout of the figure:" ] }, { "cell_type": "code", "execution_count": null, "id": "fd33cf4b-9589-4b9a-a246-0b95bad28c51", "metadata": {}, "outputs": [], "source": [ "sns.clustermap(\n", " iris,\n", " figsize=(7, 5),\n", " row_cluster=False,\n", " dendrogram_ratio=(.1, .2),\n", " cbar_pos=(0, .2, .03, .4)\n", ")" ] }, { "cell_type": "raw", "id": "c5d3408d-f5d6-4045-9d61-15573a981587", "metadata": {}, "source": [ "Add colored labels to identify observations:" ] }, { "cell_type": "code", "execution_count": null, "id": "79d3fe52-6146-4f33-a39a-1d4a47243ea5", "metadata": {}, "outputs": [], "source": [ "lut = dict(zip(species.unique(), \"rbg\"))\n", "row_colors = species.map(lut)\n", "sns.clustermap(iris, row_colors=row_colors)" ] }, { "cell_type": "raw", "id": "f2f944e2-36cd-4653-86b4-6d2affec13d6", "metadata": {}, 
"source": [ "Use a different colormap and adjust the limits of the color range:" ] }, { "cell_type": "code", "execution_count": null, "id": "6137c7ad-db92-47b8-9d00-3228c4e1f7df", "metadata": {}, "outputs": [], "source": [ "sns.clustermap(iris, cmap=\"mako\", vmin=0, vmax=10)" ] }, { "cell_type": "raw", "id": "93f96d1c-9d04-464f-93c9-4319caa8504a", "metadata": {}, "source": [ "Use different clustering parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "f9e76bde-a222-4eca-971f-54f56ad53281", "metadata": {}, "outputs": [], "source": [ "sns.clustermap(iris, metric=\"correlation\", method=\"single\")" ] }, { "cell_type": "raw", "id": "ea6ed3fd-188d-4244-adac-ec0169c02205", "metadata": {}, "source": [ "Standardize the data within the columns:" ] }, { "cell_type": "code", "execution_count": null, "id": "e5f744c4-b959-4ed1-b2cf-6046c9214568", "metadata": {}, "outputs": [], "source": [ "sns.clustermap(iris, standard_scale=1)" ] }, { "cell_type": "raw", "id": "7ca72242-4eb0-4f8e-b0c0-d1ef7166b738", "metadata": {}, "source": [ "Normalize the data within rows:" ] }, { "cell_type": "code", "execution_count": null, "id": "33815c4c-9bae-4226-bd11-3dfdb7ecab2b", "metadata": {}, "outputs": [], "source": [ "sns.clustermap(iris, z_score=0, cmap=\"vlag\", center=0)" ] }, { "cell_type": "code", "execution_count": null, "id": "0f37d57a-b049-4665-9c24-4d5fbbca00ba", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/color_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 
null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Calling with no arguments returns all colors from the current default\n", "color cycle:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Other variants on the seaborn categorical color palette can be referenced by name:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"pastel\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return a specified number of evenly spaced hues in the \"HUSL\" system:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"husl\", 9)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return all unique colors in a categorical Color Brewer palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Set2\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return a diverging Color Brewer palette as a continuous colormap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Spectral\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return one of the perceptually-uniform palettes included in seaborn as a discrete palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"flare\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return one of the perceptually-uniform palettes included in seaborn as a continuous colormap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"sns.color_palette(\"flare\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return a customized cubehelix color palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"ch:s=.25,rot=-.25\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return a light sequential gradient:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"light:#5A9\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return a reversed dark sequential gradient:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"dark:#5A9_r\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Return a blend gradient between two endpoints:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"blend:#7AB,#EDA\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use as a context manager to change the default qualitative color palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "x, y = list(range(10)), [0] * 10\n", "hue = list(map(str, x))" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with sns.color_palette(\"Set3\"):\n", " sns.relplot(x=x, y=y, hue=hue, s=500, legend=False, height=1.3, aspect=4)\n", "\n", "sns.relplot(x=x, y=y, hue=hue, s=500, legend=False, height=1.3, aspect=4)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "See the underlying color values as hex codes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "show-output" ] }, "outputs": [], "source": [ "print(sns.color_palette(\"pastel6\").as_hex())" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, 
"outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/countplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2fdf0f63-d515-4cb8-b3e0-62cac7852b12", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")\n", "titanic = sns.load_dataset(\"titanic\")" ] }, { "cell_type": "raw", "id": "af16d745-734a-4f11-9f8f-fa54deadfb12", "metadata": {}, "source": [ "Show the number of observations in each category of a single categorical variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "6e9d0485-870d-4841-9c84-6e0bacbde7db", "metadata": {}, "outputs": [], "source": [ "sns.countplot(titanic, x=\"class\")" ] }, { "cell_type": "raw", "id": "173f47c4-d5fb-4fc0-bdbd-ec228419d451", "metadata": {}, "source": [ "Group by a second variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "26f73c00-a2b3-45c3-b3cd-2babe0a81894", "metadata": {}, "outputs": [], "source": [ "sns.countplot(titanic, x=\"class\", hue=\"survived\")" ] }, { "cell_type": "raw", "id": "377bfb01-64a2-4f07-b06b-fb1a4f7c3b12", "metadata": {}, "source": [ "Normalize the counts to show percentages:" ] }, { "cell_type": "code", "execution_count": null, "id": "7267aefc-f2bc-4a64-956a-bb25013ca9ec", "metadata": {}, "outputs": [], "source": [ "sns.countplot(titanic, x=\"class\", hue=\"survived\", stat=\"percent\")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": 
"ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/cubehelix_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "60aebc68-2c7c-4af5-a159-8421e1f94ba6", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "242b3d42-1f10-4da2-9ef9-af06f7fbd724", "metadata": {}, "source": [ "Return a discrete palette with default parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "6526accb-9930-4e39-9f58-1ca2941c1c9d", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette()" ] }, { "cell_type": "raw", "id": "887a40f0-d949-41fa-9a43-0ee246c9a077", "metadata": {}, "source": [ "Increase the number of colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "02833290-b1ee-46df-a2a0-8268fba94628", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(8)" ] }, { "cell_type": "raw", "id": "a9eb86c7-f92e-4422-ae62-a2ef136e7e35", "metadata": {}, "source": [ "Return a continuous colormap rather than a discrete palette:" ] }, { "cell_type": "code", "execution_count": null, "id": "a460efc2-cf0a-46bf-a12f-12870afce8a5", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(as_cmap=True)" ] }, { "cell_type": "raw", "id": "5b84aa6c-ad79-45b1-a7d2-44b7ecba5f7d", "metadata": {}, "source": [ "Change the starting point of the helix:" ] }, { "cell_type": "code", "execution_count": null, "id": "70ee079a-e760-4d43-8447-648fd236ab15", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(start=2)" ] }, { "cell_type": "raw", "id": "5e21fa22-9ac3-4354-8694-967f2447b286", "metadata": 
{}, "source": [ "Change the amount of rotation in the helix:" ] }, { "cell_type": "code", "execution_count": null, "id": "ddb1b8c7-8933-4317-827f-4f10d2b4cecc", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(rot=.2)" ] }, { "cell_type": "raw", "id": "fa91aff7-54e7-4754-a13c-b629dfc33e8f", "metadata": {}, "source": [ "Rotate in the reverse direction:" ] }, { "cell_type": "code", "execution_count": null, "id": "548a3942-48ae-40d2-abb7-acc2ffd71601", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(rot=-.2)" ] }, { "cell_type": "raw", "id": "e7188a1b-183f-4b04-93a0-975c27fe408e", "metadata": {}, "source": [ "Apply a nonlinearity to the luminance ramp:" ] }, { "cell_type": "code", "execution_count": null, "id": "9ced54ff-a396-451e-b17f-2366b56f920b", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(gamma=.5)" ] }, { "cell_type": "raw", "id": "bc82ce48-2df3-464e-b70e-a1d73d0432c6", "metadata": {}, "source": [ "Increase the saturation of the colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "a38b91a8-3fdc-4293-a3ea-71b4006cd2a1", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(hue=1)" ] }, { "cell_type": "raw", "id": "f8d23ba1-013a-489f-94c4-f2080bfdae87", "metadata": {}, "source": [ "Change the luminance at the start and end points:" ] }, { "cell_type": "code", "execution_count": null, "id": "a4f05a16-18f0-4c14-99a4-57a0734aad02", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(dark=.25, light=.75)" ] }, { "cell_type": "raw", "id": "0bfcc5d9-05ba-4715-94ac-8d430d9416c2", "metadata": {}, "source": [ "Reverse the direction of the luminance ramp:" ] }, { "cell_type": "code", "execution_count": null, "id": "74563491-5448-42c3-86c5-f5d55ce6924c", "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(reverse=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "94a83211-8b8e-4e60-8365-9600e71ddc5d", "metadata": {}, "outputs": [], "source": [] } 
], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/dark_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "5cd1cbb8-ba1a-460b-8e3a-bc285867f1d1", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "b157eb25-015f-4dd6-9785-83ba19cf4f94", "metadata": {}, "source": [ "Define a sequential ramp from a dark gray to a specified color:" ] }, { "cell_type": "code", "execution_count": null, "id": "5b655d28-9855-4528-8b8e-a6c50288fd1b", "metadata": {}, "outputs": [], "source": [ "sns.dark_palette(\"seagreen\")" ] }, { "cell_type": "raw", "id": "50053b26-112a-4378-8ef0-9be0fb565ec7", "metadata": {}, "source": [ "Specify the color with a hex code:" ] }, { "cell_type": "code", "execution_count": null, "id": "74ae0d17-f65b-4bcf-ae66-d97d46964d5c", "metadata": {}, "outputs": [], "source": [ "sns.dark_palette(\"#79C\")" ] }, { "cell_type": "raw", "id": "eea376a2-fdf5-40e4-a187-3a28af529072", "metadata": {}, "source": [ "Specify the color from the husl system:" ] }, { "cell_type": "code", "execution_count": null, "id": "66e451ee-869a-41ea-8dc5-4240b11e7be5", "metadata": {}, "outputs": [], "source": [ "sns.dark_palette((20, 60, 50), input=\"husl\")" ] }, { "cell_type": "raw", "id": "e4f44dcd-cf49-4920-ac05-b4db67870363", "metadata": {}, "source": [ "Increase the number of colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "75985f07-de92-4d8b-89d5-caf445b9375e", "metadata": 
{}, "outputs": [], "source": [ "sns.dark_palette(\"xkcd:golden\", 8)" ] }, { "cell_type": "raw", "id": "34687ae8-fd6d-427a-a639-208f19e61122", "metadata": {}, "source": [ "Return a continuous colormap rather than a discrete palette:" ] }, { "cell_type": "code", "execution_count": null, "id": "2c342db4-7f97-40f5-934e-9a82201890d1", "metadata": {}, "outputs": [], "source": [ "sns.dark_palette(\"#b285bc\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "e7ebe64b-25fa-4c52-9ebe-fdcbba0ee51e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/displot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns; sns.set_theme(style=\"ticks\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The default plot kind is a histogram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.displot(data=penguins, x=\"flipper_length_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use the ``kind`` parameter to select a different representation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", kind=\"kde\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are three main plot kinds; in addition to histograms and kernel density estimates (KDEs), you can also 
draw empirical cumulative distribution functions (ECDFs):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", kind=\"ecdf\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "While in histogram mode, it is also possible to add a KDE curve:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", kde=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To draw a bivariate plot, assign both ``x`` and ``y``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Currently, bivariate plots are available only for histograms and KDEs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", kind=\"kde\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "For each kind of plot, you can also show individual observations with a marginal \"rug\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.displot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", kind=\"kde\", rug=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Each kind of plot can be drawn separately for subsets of data using ``hue`` mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Additional keyword arguments are passed to the appropriate underlying plotting function, allowing for further customization:" ] }, { 
"cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The figure is constructed using a :class:`FacetGrid`, meaning that you can also show subsets on distinct subplots, or \"facets\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", col=\"sex\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Because the figure is drawn with a :class:`FacetGrid`, you control its size and shape with the ``height`` and ``aspect`` parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(\n", " data=penguins, y=\"flipper_length_mm\", hue=\"sex\", col=\"species\",\n", " kind=\"ecdf\", height=4, aspect=.7,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The function returns the :class:`FacetGrid` object with the plot, and you can use the methods on this object to customize it further:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.displot(\n", " data=penguins, y=\"flipper_length_mm\", hue=\"sex\", col=\"species\",\n", " kind=\"kde\", height=4, aspect=.7,\n", ")\n", "g.set_axis_labels(\"Density (a.u.)\", \"Flipper length (mm)\")\n", "g.set_titles(\"{col_name} penguins\")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/diverging_palette.ipynb 
================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "01295cb6-cc7a-4c6d-94cf-9b0e6cde9fa7", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "84880848-0805-4c41-999a-50808b397275", "metadata": {}, "source": [ "Generate diverging ramps from blue to red through white:" ] }, { "cell_type": "code", "execution_count": null, "id": "643b3e07-8365-46e3-b033-af7a2fdcd158", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(240, 20)" ] }, { "cell_type": "raw", "id": "5ae53941-d9d9-4b5a-8abc-173911ebee74", "metadata": {}, "source": [ "Change the center color to be dark:" ] }, { "cell_type": "code", "execution_count": null, "id": "41f03771-8fb2-46f6-93c5-5a0e28be625c", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(240, 20, center=\"dark\")" ] }, { "cell_type": "raw", "id": "0aeb2402-2cbe-4546-a354-f1f501f762ae", "metadata": {}, "source": [ "Return a continuous colormap rather than a discrete palette:" ] }, { "cell_type": "code", "execution_count": null, "id": "64d335a5-f8b2-433f-a83f-5aeff7db583a", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(240, 20, as_cmap=True)" ] }, { "cell_type": "raw", "id": "77223a07-8492-4056-a0f7-14e133e3ce2c", "metadata": {}, "source": [ "Increase the amount of separation around the center value:" ] }, { "cell_type": "code", "execution_count": null, "id": "82472c1e-4b16-40eb-be1d-480bbd2aa702", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(240, 20, sep=30, as_cmap=True)" ] }, { "cell_type": "raw", "id": "966e8594-b458-414c-a7b0-3e804ce407bf", "metadata": {}, "source": [ "Use a magenta-to-green palette instead:" ] }, { "cell_type": "code", "execution_count": null, "id": "a03f8ede-b424-4e06-beb6-cf63c94bcd9e", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(280, 
150)" ] }, { "cell_type": "raw", "id": "b3b17689-58e2-4065-9d52-1cf5ebcd4e89", "metadata": {}, "source": [ "Decrease the saturation of the endpoints:" ] }, { "cell_type": "code", "execution_count": null, "id": "02aaa009-f257-4fc7-a2de-40fbb1464490", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(280, 150, s=50)" ] }, { "cell_type": "raw", "id": "db75ca48-ba72-4ca2-8480-bc72c20a70cc", "metadata": {}, "source": [ "Decrease the lightness of the endpoints:" ] }, { "cell_type": "code", "execution_count": null, "id": "89e3bcb1-a17c-4465-830f-46043cb6c322", "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(280, 150, l=35)" ] }, { "cell_type": "code", "execution_count": null, "id": "4e42452a-a485-43e7-bbc3-338db58e4637", "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "code", "execution_count": null, "id": "e19f523f-c2f7-489a-ba00-326810e31a67", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/ecdfplot.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Plot a univariate distribution along the x axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns; sns.set_theme()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.ecdfplot(data=penguins, x=\"flipper_length_mm\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Flip the plot 
by assigning the data variable to the y axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.ecdfplot(data=penguins, y=\"flipper_length_mm\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If neither `x` nor `y` is assigned, the dataset is treated as wide-form, and an ECDF is drawn for each numeric column:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.ecdfplot(data=penguins.filter(like=\"bill_\", axis=\"columns\"))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also draw multiple ECDFs from a long-form dataset with hue mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.ecdfplot(data=penguins, x=\"bill_length_mm\", hue=\"species\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default distribution statistic is normalized to show a proportion, but you can show absolute counts or percents instead:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.ecdfplot(data=penguins, x=\"bill_length_mm\", hue=\"species\", stat=\"count\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's also possible to plot the empirical complementary CDF (1 - CDF):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.ecdfplot(data=penguins, x=\"bill_length_mm\", hue=\"species\", complementary=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4,
"nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/heatmap.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "987b9549-532e-4091-a6cf-007d1b23e825", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()" ] }, { "cell_type": "raw", "id": "2c78ca60-e232-44f6-956b-b86b472b1c28", "metadata": {}, "source": [ "Pass a :class:`DataFrame` to plot with indices as row/column labels:" ] }, { "cell_type": "code", "execution_count": null, "id": "fad17798-c2e3-4334-abf0-0d46153971fa", "metadata": {}, "outputs": [], "source": [ "glue = sns.load_dataset(\"glue\").pivot(index=\"Model\", columns=\"Task\", values=\"Score\")\n", "sns.heatmap(glue)" ] }, { "cell_type": "raw", "id": "f3255c5f-2477-4d13-b4c2-7e56380e9cc2", "metadata": {}, "source": [ "Use `annot` to represent the cell values with text:" ] }, { "cell_type": "code", "execution_count": null, "id": "3c9f3c73-c8bc-426e-bc67-dec8f807082e", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, annot=True)" ] }, { "cell_type": "raw", "id": "bc412da8-866a-49b7-8496-01fbf06dd908", "metadata": {}, "source": [ "Control the annotations with a formatting string:" ] }, { "cell_type": "code", "execution_count": null, "id": "ac952d0d-9187-4dff-a560-88430076851a", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, annot=True, fmt=\".1f\")" ] }, { "cell_type": "raw", "id": "5eb12725-e9ee-4df0-9708-243d7e0a77b5", "metadata": {}, "source": [ "Use a separate dataframe for the annotations:" ] }, { "cell_type": "code", "execution_count": null, "id": "1189a37f-9f74-455a-a09a-c22e056d8ba7", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, annot=glue.rank(axis=\"columns\"))" ] }, { "cell_type": "raw", "id": "253dfb7f-aa12-4716-adc2-3a38b003b2c3", "metadata": {}, "source": [ "Add lines between cells:" ] }, { "cell_type": "code", "execution_count": null, 
"id": "5cac673e-9b86-490b-9e67-ec0cf865bede", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, annot=True, linewidth=.5)" ] }, { "cell_type": "raw", "id": "b7d3659c-f996-4af3-a612-430d97799c72", "metadata": {}, "source": [ "Select a different colormap by name:" ] }, { "cell_type": "code", "execution_count": null, "id": "86806d72-e784-430e-8320-48f2c91115bb", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, cmap=\"crest\")" ] }, { "cell_type": "raw", "id": "8336fd53-3841-458f-b26c-411efff54d45", "metadata": {}, "source": [ "Or pass a colormap object:" ] }, { "cell_type": "code", "execution_count": null, "id": "9944ff33-991f-4138-a951-e3015c0326f1", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, cmap=sns.cubehelix_palette(as_cmap=True))" ] }, { "cell_type": "raw", "id": "52cc4dba-b86a-4da8-9cbd-3f8aa06b43b4", "metadata": {}, "source": [ "Set the colormap norm (data values corresponding to minimum and maximum points):" ] }, { "cell_type": "code", "execution_count": null, "id": "b4ddb41e-c075-41a5-8afe-422ad6d105bf", "metadata": {}, "outputs": [], "source": [ "sns.heatmap(glue, vmin=50, vmax=100)" ] }, { "cell_type": "raw", "id": "6e828517-a532-49b1-be11-eda47c50cc37", "metadata": {}, "source": [ "Use methods on the :class:`matplotlib.axes.Axes` object to tweak the plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "1aab26fc-2de4-4d4f-ad08-487809573deb", "metadata": {}, "outputs": [], "source": [ "ax = sns.heatmap(glue, annot=True)\n", "ax.set(xlabel=\"\", ylabel=\"\")\n", "ax.xaxis.tick_top()" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ 
FILE: doc/_docstrings/histplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"white\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assign a variable to ``x`` to plot a univariate distribution along the x axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.histplot(data=penguins, x=\"flipper_length_mm\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Flip the plot by assigning the data variable to the y axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins, y=\"flipper_length_mm\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Check how well the histogram represents the data by specifying a different bin width:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins, x=\"flipper_length_mm\", binwidth=3)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can also define the total number of bins to use:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins, x=\"flipper_length_mm\", bins=30)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add a kernel density estimate to smooth the histogram, providing complementary information about the shape of the distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins, x=\"flipper_length_mm\", kde=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "If neither `x` nor `y` is assigned, the dataset is treated as wide-form, and a histogram is drawn for each numeric column:" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can otherwise draw multiple histograms from a long-form dataset with hue mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default approach to plotting multiple distributions is to \"layer\" them, but you can also \"stack\" them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Overlapping bars can be hard to visually resolve. A different approach would be to draw a step function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(penguins, x=\"flipper_length_mm\", hue=\"species\", element=\"step\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can move even farther away from bars by drawing a polygon with vertices in the center of each bin. 
This may make it easier to see the shape of the distribution, but use with caution: it will be less obvious to your audience that they are looking at a histogram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(penguins, x=\"flipper_length_mm\", hue=\"species\", element=\"poly\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To compare the distribution of subsets that differ substantially in size, use independent density normalization:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " penguins, x=\"bill_length_mm\", hue=\"island\", element=\"step\",\n", " stat=\"density\", common_norm=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's also possible to normalize so that each bar's height shows a probability, proportion, or percent, which make more sense for discrete variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.histplot(data=tips, x=\"size\", stat=\"percent\", discrete=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "You can even draw a histogram over categorical variables (although this is an experimental feature):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=tips, x=\"day\", shrink=.8)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When using a ``hue`` semantic with discrete data, it can make sense to \"dodge\" the levels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=tips, x=\"day\", hue=\"sex\", multiple=\"dodge\", shrink=.8)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Real-world data is often skewed. For heavily skewed distributions, it's better to define the bins in log space. 
Compare:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "planets = sns.load_dataset(\"planets\")\n", "sns.histplot(data=planets, x=\"distance\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To the log-scale version:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=planets, x=\"distance\", log_scale=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "There are also a number of options for how the histogram appears. You can show unfilled bars:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=planets, x=\"distance\", log_scale=True, fill=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Or an unfilled step function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(data=planets, x=\"distance\", log_scale=True, element=\"step\", fill=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Step functions, especially when unfilled, make it easy to compare cumulative histograms:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " data=planets, x=\"distance\", hue=\"method\",\n", " hue_order=[\"Radial Velocity\", \"Transit\"],\n", " log_scale=True, element=\"step\", fill=False,\n", " cumulative=True, stat=\"density\", common_norm=False,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "When both ``x`` and ``y`` are assigned, a bivariate histogram is computed and shown as a heatmap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(penguins, x=\"bill_depth_mm\", y=\"body_mass_g\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's possible to assign a ``hue`` variable too, although this will not work well if 
data from the different levels have substantial overlap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(penguins, x=\"bill_depth_mm\", y=\"body_mass_g\", hue=\"species\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Multiple color maps can make sense when one of the variables is discrete:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " penguins, x=\"bill_depth_mm\", y=\"species\", hue=\"species\", legend=False\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The bivariate histogram accepts all of the same options for computation as its univariate counterpart, using tuples to parametrize ``x`` and ``y`` independently:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " planets, x=\"year\", y=\"distance\",\n", " bins=30, discrete=(True, False), log_scale=(False, True),\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default behavior makes cells with no observations transparent, although this can be disabled: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " planets, x=\"year\", y=\"distance\",\n", " bins=30, discrete=(True, False), log_scale=(False, True),\n", " thresh=None,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "It's also possible to set the threshold and colormap saturation point in terms of the proportion of cumulative counts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " planets, x=\"year\", y=\"distance\",\n", " bins=30, discrete=(True, False), log_scale=(False, True),\n", " pthresh=.05, pmax=.9,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To annotate the colormap, add a colorbar:" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.histplot(\n", " planets, x=\"year\", y=\"distance\",\n", " bins=30, discrete=(True, False), log_scale=(False, True),\n", " cbar=True, cbar_kws=dict(shrink=.75),\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/hls_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "158cd1cf-6b30-4054-b32f-a166fcb883be", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "c81b86cb-fb4e-418b-8d2f-6cd10601ac5a", "metadata": {}, "source": [ "By default, return 6 colors with identical lightness and saturation and evenly-sampled hues:" ] }, { "cell_type": "code", "execution_count": null, "id": "6c3eaeaf-88eb-4012-96ea-41b328fa98b9", "metadata": {}, "outputs": [], "source": [ "sns.hls_palette()" ] }, { "cell_type": "raw", "id": "f7624b0b-2311-45de-b6a5-fc07132ce455", "metadata": {}, "source": [ "Increase the number of colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "555c29d1-6972-4a19-ad32-957fb7545634", "metadata": {}, "outputs": [], "source": [ "sns.hls_palette(8)" ] }, { "cell_type": "raw", "id": "24713fa6-e485-4358-9ffc-d40bd9543caa", "metadata": {}, "source": [ "Decrease the lightness:" ] }, { "cell_type": "code", "execution_count": null, "id": "b6f80b4c-f7b4-4deb-a119-cdf6cfe1f7b5", "metadata": {}, "outputs": [], "source": [ "sns.hls_palette(l=.3)" ] }, { "cell_type": "raw", "id": 
"e521b514-5572-43e8-95ae-a20cc30169b8", "metadata": {}, "source": [ "Decrease the saturation:" ] }, { "cell_type": "code", "execution_count": null, "id": "f88bd038-0c9c-48b1-92b0-d272a9c199f4", "metadata": {}, "outputs": [], "source": [ "sns.hls_palette(s=.3)" ] }, { "cell_type": "raw", "id": "92a2212c-2177-4c82-8a5e-9dd788e9f87c", "metadata": {}, "source": [ "Change the start-point for hue sampling:" ] }, { "cell_type": "code", "execution_count": null, "id": "f8da8fbc-551c-4896-b1b8-04203e740d78", "metadata": {}, "outputs": [], "source": [ "sns.hls_palette(h=.5)" ] }, { "cell_type": "raw", "id": "87780608-1f5a-409f-b31f-6a31a599f122", "metadata": {}, "source": [ "Return a continuous colormap. Notice the perceptual discontinuities, especially around yellow, cyan, and magenta: " ] }, { "cell_type": "code", "execution_count": null, "id": "4c622b3b-70d7-4139-8389-f3d0d4addd66", "metadata": {}, "outputs": [], "source": [ "sns.hls_palette(as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "3a83c1de-88c5-4327-abd2-19e8f3642052", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/husl_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "a6794650-f28f-40eb-95a7-3f0e5c4b332d", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "fab2f86e-45d4-4982-ade7-0a5ea6d762d1", "metadata": {}, "source": [ "By default, return 6 
colors with identical lightness and saturation and evenly-sampled hues:" ] }, { "cell_type": "code", "execution_count": null, "id": "b220950e-0ca2-4101-b56a-14eebe8ee8d0", "metadata": {}, "outputs": [], "source": [ "sns.husl_palette()" ] }, { "cell_type": "raw", "id": "c5e4a2e3-e6b8-42bf-be19-348ff7ae2798", "metadata": {}, "source": [ "Increase the number of colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "7d0af740-cfca-49fb-a472-1daa4ccb3f3a", "metadata": {}, "outputs": [], "source": [ "sns.husl_palette(8)" ] }, { "cell_type": "raw", "id": "1a7189f2-2a26-446a-90e7-cf41dcac4f25", "metadata": {}, "source": [ "Decrease the lightness:" ] }, { "cell_type": "code", "execution_count": null, "id": "43af79c7-f497-41e5-874a-83eed99500f3", "metadata": {}, "outputs": [], "source": [ "sns.husl_palette(l=.4)" ] }, { "cell_type": "raw", "id": "6d4099b7-5115-4365-b120-33a345581f5d", "metadata": {}, "source": [ "Decrease the saturation:" ] }, { "cell_type": "code", "execution_count": null, "id": "52c1afc7-d982-4199-b218-222aa94563c5", "metadata": {}, "outputs": [], "source": [ "sns.husl_palette(s=.4)" ] }, { "cell_type": "raw", "id": "d26131ac-0d11-48c5-88b1-4e5cf9383000", "metadata": {}, "source": [ "Change the start-point for hue sampling:" ] }, { "cell_type": "code", "execution_count": null, "id": "d72f06a0-13e0-47f7-bc70-4c5935eaa130", "metadata": {}, "outputs": [], "source": [ "sns.husl_palette(h=.5)" ] }, { "cell_type": "raw", "id": "7e6c3c19-41d3-4315-b03e-909d201d0e76", "metadata": {}, "source": [ "Return a continuous colormap:" ] }, { "cell_type": "code", "execution_count": null, "id": "49c18838-0589-496f-9a61-635195c07f61", "metadata": {}, "outputs": [], "source": [ "sns.husl_palette(as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "c710a557-8e84-44cb-ab4c-baabcc4fd328", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, 
"language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/jointplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"white\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In the simplest invocation, assign ``x`` and ``y`` to create a scatterplot (using :func:`scatterplot`) with marginal histograms (using :func:`histplot`):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a ``hue`` variable will add conditional colors to the scatterplot and draw separate density curves (using :func:`kdeplot`) on the marginal axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Several different approaches to plotting are available through the ``kind`` parameter. 
Setting ``kind=\"kde\"`` will draw both bivariate and univariate KDEs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Set ``kind=\"reg\"`` to add a linear regression fit (using :func:`regplot`) and univariate KDE curves:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"reg\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "There are also two options for bin-based visualization of the joint distribution. The first, with ``kind=\"hist\"``, uses :func:`histplot` on all of the axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"hist\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Alternatively, setting ``kind=\"hex\"`` will use :meth:`matplotlib.axes.Axes.hexbin` to compute a bivariate histogram using hexagonal bins:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"hex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Additional keyword arguments can be passed down to the underlying plots:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(\n", " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " marker=\"+\", s=100, marginal_kws=dict(bins=25, fill=False),\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use :class:`JointGrid` parameters to control the size and layout of the figure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": 
[], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", height=5, ratio=2, marginal_ticks=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To add more layers onto the plot, use the methods on the :class:`JointGrid` object that :func:`jointplot` returns:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot_joint(sns.kdeplot, color=\"r\", zorder=0, levels=6)\n", "g.plot_marginals(sns.rugplot, color=\"r\", height=-.15, clip_on=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/kdeplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns; sns.set_theme()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot a univariate distribution along the x axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.kdeplot(data=tips, x=\"total_bill\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Flip the plot by assigning the data variable to the y axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=tips, y=\"total_bill\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot distributions for each column of a 
wide-form dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "iris = sns.load_dataset(\"iris\")\n", "sns.kdeplot(data=iris)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use less smoothing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=tips, x=\"total_bill\", bw_adjust=.2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use more smoothing, but don't smooth past the extreme data points:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ax= sns.kdeplot(data=tips, x=\"total_bill\", bw_adjust=5, cut=0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot conditional distributions with hue mapping of a second variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"time\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "\"Stack\" the conditional distributions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"time\", multiple=\"stack\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Normalize the stacked distribution at each value in the grid:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"time\", multiple=\"fill\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Estimate the cumulative distribution function(s), normalizing each subset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(\n", " data=tips, x=\"total_bill\", hue=\"time\",\n", " cumulative=True, common_norm=False, common_grid=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Estimate 
distribution from aggregated data, using weights:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips_agg = (tips\n", " .groupby(\"size\")\n", " .agg(total_bill=(\"total_bill\", \"mean\"), n=(\"total_bill\", \"count\"))\n", ")\n", "sns.kdeplot(data=tips_agg, x=\"total_bill\", weights=\"n\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Map the data variable with log scaling:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "diamonds = sns.load_dataset(\"diamonds\")\n", "sns.kdeplot(data=diamonds, x=\"price\", log_scale=True)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Use numeric hue mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=tips, x=\"total_bill\", hue=\"size\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Modify the appearance of the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(\n", " data=tips, x=\"total_bill\", hue=\"size\",\n", " fill=True, common_norm=False, palette=\"crest\",\n", " alpha=.5, linewidth=0,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Plot a bivariate distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "geyser = sns.load_dataset(\"geyser\")\n", "sns.kdeplot(data=geyser, x=\"waiting\", y=\"duration\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Map a third variable with a hue semantic to show conditional distributions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=geyser, x=\"waiting\", y=\"duration\", hue=\"kind\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show filled contours:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], 
"source": [ "sns.kdeplot(\n", " data=geyser, x=\"waiting\", y=\"duration\", hue=\"kind\", fill=True,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show fewer contour levels, covering less of the distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(\n", " data=geyser, x=\"waiting\", y=\"duration\", hue=\"kind\",\n", " levels=5, thresh=.2,\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Fill the axes extent with a smooth distribution, using a different colormap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(\n", " data=geyser, x=\"waiting\", y=\"duration\",\n", " fill=True, thresh=0, levels=100, cmap=\"mako\",\n", ")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/light_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "5cd1cbb8-ba1a-460b-8e3a-bc285867f1d1", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "raw", "id": "b157eb25-015f-4dd6-9785-83ba19cf4f94", "metadata": {}, "source": [ "Define a sequential ramp from a light gray to a specified color:" ] }, { "cell_type": "code", "execution_count": null, "id": "851a4742-6276-4383-b17e-480beb896877", "metadata": {}, "outputs": [], "source": [ "sns.light_palette(\"seagreen\")" ] }, { "cell_type": "raw", "id": "50053b26-112a-4378-8ef0-9be0fb565ec7", 
"metadata": {}, "source": [ "Specify the color with a hex code:" ] }, { "cell_type": "code", "execution_count": null, "id": "74ae0d17-f65b-4bcf-ae66-d97d46964d5c", "metadata": {}, "outputs": [], "source": [ "sns.light_palette(\"#79C\")" ] }, { "cell_type": "raw", "id": "eea376a2-fdf5-40e4-a187-3a28af529072", "metadata": {}, "source": [ "Specify the color from the husl system:" ] }, { "cell_type": "code", "execution_count": null, "id": "66e451ee-869a-41ea-8dc5-4240b11e7be5", "metadata": {}, "outputs": [], "source": [ "sns.light_palette((20, 60, 50), input=\"husl\")" ] }, { "cell_type": "raw", "id": "e4f44dcd-cf49-4920-ac05-b4db67870363", "metadata": {}, "source": [ "Increase the number of colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "75985f07-de92-4d8b-89d5-caf445b9375e", "metadata": {}, "outputs": [], "source": [ "sns.light_palette(\"xkcd:copper\", 8)" ] }, { "cell_type": "raw", "id": "34687ae8-fd6d-427a-a639-208f19e61122", "metadata": {}, "source": [ "Return a continuous colormap rather than a discrete palette:" ] }, { "cell_type": "code", "execution_count": null, "id": "2c342db4-7f97-40f5-934e-9a82201890d1", "metadata": {}, "outputs": [], "source": [ "sns.light_palette(\"#a275ac\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "e7ebe64b-25fa-4c52-9ebe-fdcbba0ee51e", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/lineplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": 
[], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``flights`` dataset has 10 years of monthly airline passenger data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights = sns.load_dataset(\"flights\")\n", "flights.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To draw a line plot using long-form data, assign the ``x`` and ``y`` variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "may_flights = flights.query(\"month == 'May'\")\n", "sns.lineplot(data=may_flights, x=\"year\", y=\"passengers\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Pivot the dataframe to a wide-form representation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_wide = flights.pivot(index=\"year\", columns=\"month\", values=\"passengers\")\n", "flights_wide.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To plot a single vector, pass it to ``data``. 
If the vector is a :class:`pandas.Series`, it will be plotted against its index:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=flights_wide[\"May\"])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Passing the entire wide-form dataset to ``data`` plots a separate line for each column:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=flights_wide)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Passing the entire dataset in long-form mode will aggregate over repeated values (each year) to show the mean and 95% confidence interval:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=flights, x=\"year\", y=\"passengers\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assign a grouping semantic (``hue``, ``size``, or ``style``) to plot separate lines:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The same column can be assigned to multiple semantic variables, which can increase the accessibility of the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\", style=\"month\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use the `orient` parameter to aggregate and sort along the vertical dimension of the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=flights, x=\"passengers\", y=\"year\", orient=\"y\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Each semantic variable can also represent a different column. 
For that, we'll need a more complex dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fmri = sns.load_dataset(\"fmri\")\n", "fmri.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Repeated observations are aggregated even when semantic grouping is used:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=fmri, x=\"timepoint\", y=\"signal\", hue=\"event\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assign both ``hue`` and ``style`` to represent two different grouping variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(data=fmri, x=\"timepoint\", y=\"signal\", hue=\"region\", style=\"event\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When assigning a ``style`` variable, markers can be used instead of (or along with) dashes to distinguish the groups:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=fmri,\n", " x=\"timepoint\", y=\"signal\", hue=\"event\", style=\"event\",\n", " markers=True, dashes=False\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Show error bars instead of error bands and extend them to two standard error widths:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=fmri, x=\"timepoint\", y=\"signal\", hue=\"event\", err_style=\"bars\", errorbar=(\"se\", 2),\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning the ``units`` variable will plot multiple lines without applying a semantic mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=fmri.query(\"region == 'frontal'\"),\n", " x=\"timepoint\", y=\"signal\", hue=\"event\", units=\"subject\",\n", " 
estimator=None, lw=1,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Load another dataset with a numeric grouping variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dots = sns.load_dataset(\"dots\").query(\"align == 'dots'\")\n", "dots.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a numeric variable to ``hue`` maps it differently, using a different default palette and a quantitative color mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=dots, x=\"time\", y=\"firing_rate\", hue=\"coherence\", style=\"choice\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Control the color mapping by setting the ``palette`` and passing a :class:`matplotlib.colors.Normalize` object:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=dots.query(\"coherence > 0\"),\n", " x=\"time\", y=\"firing_rate\", hue=\"coherence\", style=\"choice\",\n", " palette=\"flare\", hue_norm=mpl.colors.LogNorm(),\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or pass specific colors, either as a Python list or dictionary:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "palette = sns.color_palette(\"mako_r\", 6)\n", "sns.lineplot(\n", " data=dots, x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", style=\"choice\",\n", " palette=palette\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assign the ``size`` semantic to map the width of the lines with a numeric variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=dots, x=\"time\", y=\"firing_rate\",\n", " size=\"coherence\", hue=\"choice\",\n", " legend=\"full\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ 
"Pass a tuple, ``sizes=(smallest, largest)``, to control the range of linewidths used to map the ``size`` semantic:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lineplot(\n", " data=dots, x=\"time\", y=\"firing_rate\",\n", " size=\"coherence\", hue=\"choice\",\n", " sizes=(.25, 2.5)\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the observations are sorted by ``x``. Disable this to plot a line with the order that observations appear in the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x, y = np.random.normal(size=(2, 5000)).cumsum(axis=1)\n", "sns.lineplot(x=x, y=y, sort=False, lw=1)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use :func:`relplot` to combine :func:`lineplot` and :class:`FacetGrid`. This allows grouping within additional categorical variables. Using :func:`relplot` is safer than using :class:`FacetGrid` directly, as it ensures synchronization of the semantic mappings across facets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, x=\"timepoint\", y=\"signal\",\n", " col=\"region\", hue=\"event\", style=\"event\",\n", " kind=\"line\"\n", ")" ] } ], "metadata": { "interpreter": { "hash": "8bdfc9d9da1e36addfcfc8a3409187c45d33387af0f87d0d91e99e8d6403f1c3" }, "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/lmplot.ipynb ================================================ { "cells": [ { "cell_type": "raw", "id": "034a9a5b-91ff-4ccc-932d-0f314e2cd6d2", 
"metadata": {}, "source": [ "See the :func:`regplot` docs for demonstrations of various options for specifying the regression model, which are also accepted here." ] }, { "cell_type": "code", "execution_count": null, "id": "76c91243-3bd8-49a1-b8c8-b7272f09a3f1", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"ticks\")\n", "penguins = sns.load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "0ba9f55d-17ea-4084-a74f-852d51771380", "metadata": {}, "source": [ "Plot a regression fit over a scatter plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "2f789265-93c0-4867-b666-798713e4e7e5", "metadata": {}, "outputs": [], "source": [ "sns.lmplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "id": "7e4b0ad4-446c-4109-9393-961f76132e34", "metadata": {}, "source": [ "Condition the regression fit on another variable and represent it using color:" ] }, { "cell_type": "code", "execution_count": null, "id": "61347189-34e5-42ea-b77b-4acdef843326", "metadata": {}, "outputs": [], "source": [ "sns.lmplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "id": "c9b6d059-49dc-46a7-869b-86baa3a7ed65", "metadata": {}, "source": [ "Condition the regression fit on another variable and split across subplots:" ] }, { "cell_type": "code", "execution_count": null, "id": "d8ec2955-ccc9-493c-b9ec-c78648ce9f53", "metadata": {}, "outputs": [], "source": [ "sns.lmplot(\n", " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " hue=\"species\", col=\"sex\", height=4,\n", ")" ] }, { "cell_type": "raw", "id": "de01dee1-b2ce-445c-8d0d-d054ca0dfedb", "metadata": {}, "source": [ "Condition across two variables using both columns and rows:" ] }, { "cell_type": "code", "execution_count": null, "id": "6f1264aa-829c-416a-805a-b989e5f11a17", "metadata": {}, "outputs": [], "source": [ "sns.lmplot(\n", " data=penguins, 
x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " col=\"species\", row=\"sex\", height=3,\n", ")" ] }, { "cell_type": "raw", "id": "b3888f04-b22f-4205-8acc-24ce5b59568e", "metadata": {}, "source": [ "Allow axis limits to vary across subplots:" ] }, { "cell_type": "code", "execution_count": null, "id": "67ed5af1-d228-4b81-b4f8-21937c513a10", "metadata": {}, "outputs": [], "source": [ "sns.lmplot(\n", " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " col=\"species\", row=\"sex\", height=3,\n", " facet_kws=dict(sharex=False, sharey=False),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "46e9cf18-c847-4c40-8e38-6c20cdde2be5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/move_legend.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "8ec46ad8-bc4c-4ee0-9626-271088c702f9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "penguins = sns.load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "008bdd98-88cb-4a81-9f50-9b0e5a357305", "metadata": {}, "source": [ "For axes-level functions, pass the :class:`matplotlib.axes.Axes` object and provide a new location." 
] }, { "cell_type": "code", "execution_count": null, "id": "b82e58f9-b15d-4554-bee5-de6a689344a6", "metadata": {}, "outputs": [], "source": [ "ax = sns.histplot(penguins, x=\"bill_length_mm\", hue=\"species\")\n", "sns.move_legend(ax, \"center right\")" ] }, { "cell_type": "raw", "id": "4f2a7f5d-ab39-46c7-87f4-532e607adf0b", "metadata": {}, "source": [ "Use the `bbox_to_anchor` parameter for more fine-grained control, including moving the legend outside of the axes:" ] }, { "cell_type": "code", "execution_count": null, "id": "ed610a98-447a-4459-8342-48abc80330f0", "metadata": {}, "outputs": [], "source": [ "ax = sns.histplot(penguins, x=\"bill_length_mm\", hue=\"species\")\n", "sns.move_legend(ax, \"upper left\", bbox_to_anchor=(1, 1))" ] }, { "cell_type": "raw", "id": "9d2fd766-a806-45d9-949d-1572991cf512", "metadata": {}, "source": [ "Pass additional :meth:`matplotlib.axes.Axes.legend` parameters to update other properties:" ] }, { "cell_type": "code", "execution_count": null, "id": "5ad4342c-c46e-49e9-98a2-6c88c6fb4c54", "metadata": {}, "outputs": [], "source": [ "ax = sns.histplot(penguins, x=\"bill_length_mm\", hue=\"species\")\n", "sns.move_legend(\n", " ax, \"lower center\",\n", " bbox_to_anchor=(.5, 1), ncol=3, title=None, frameon=False,\n", ")" ] }, { "cell_type": "raw", "id": "0d573092-46fd-4a95-b7ed-7e6833823adc", "metadata": {}, "source": [ "It's also possible to move the legend created by a figure-level function. 
But when fine-tuning the position, you must bear in mind that the figure will have extra blank space on the right:" ] }, { "cell_type": "code", "execution_count": null, "id": "b258a9b8-69e5-4d4a-94cb-5b6baddc402b", "metadata": {}, "outputs": [], "source": [ "g = sns.displot(\n", " penguins,\n", " x=\"bill_length_mm\", hue=\"species\",\n", " col=\"island\", col_wrap=2, height=3,\n", ")\n", "sns.move_legend(g, \"upper left\", bbox_to_anchor=(.55, .45))" ] }, { "cell_type": "raw", "id": "c9dc54e2-2c66-412f-ab2a-4f2bc2cb5782", "metadata": {}, "source": [ "One way to avoid this would be to set `legend_out=False` on the :class:`FacetGrid`:" ] }, { "cell_type": "code", "execution_count": null, "id": "06cff408-4cdf-47af-8def-176f3e70ec5a", "metadata": {}, "outputs": [], "source": [ "g = sns.displot(\n", " penguins,\n", " x=\"bill_length_mm\", hue=\"species\",\n", " col=\"island\", col_wrap=2, height=3,\n", " facet_kws=dict(legend_out=False),\n", ")\n", "sns.move_legend(g, \"upper left\", bbox_to_anchor=(.55, .45), frameon=False)" ] }, { "cell_type": "code", "execution_count": null, "id": "b170f20d-22a9-4f7d-917a-d09e10b1f08c", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/mpl_palette.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "1d0d41d3-463c-4c6f-aa65-38131bdf3ddb", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "sns.palettes._patch_colormap_display()" ] }, { "cell_type": "markdown", "id": 
"d2a0ae1e-a01e-49b3-a677-2b05a195990a", "metadata": {}, "source": [ "Return discrete samples from a continuous matplotlib colormap:" ] }, { "cell_type": "code", "execution_count": null, "id": "2b6a4ce9-6e4e-4b59-ada8-14ef8aef21d7", "metadata": {}, "outputs": [], "source": [ "sns.mpl_palette(\"viridis\")" ] }, { "cell_type": "raw", "id": "0ccc47b1-c969-46e2-93bb-b9eb5a2e2141", "metadata": {}, "source": [ "Return the continuous colormap instead; note how the extreme values are more intense:" ] }, { "cell_type": "code", "execution_count": null, "id": "a8a1bc5d-1d62-45c6-a53b-9fadb58f11c0", "metadata": {}, "outputs": [], "source": [ "sns.mpl_palette(\"viridis\", as_cmap=True)" ] }, { "cell_type": "raw", "id": "ff0d1a3b-8641-40c0-bb4b-c22b83ec9432", "metadata": {}, "source": [ "Return more colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "8faef1d8-a1eb-4060-be10-377342c9bd1d", "metadata": {}, "outputs": [], "source": [ "sns.mpl_palette(\"viridis\", 8)" ] }, { "cell_type": "raw", "id": "612bf052-e888-411d-a2ea-6a742a78bc63", "metadata": {}, "source": [ "Return values from a qualitative colormap:" ] }, { "cell_type": "code", "execution_count": null, "id": "74db95a8-4898-4f6c-a57d-c751af1dc7bf", "metadata": {}, "outputs": [], "source": [ "sns.mpl_palette(\"Set2\")" ] }, { "cell_type": "raw", "id": "918494bf-1b8e-4b00-8950-1bd73032dee1", "metadata": {}, "source": [ "Notice how the palette will only contain distinct colors and can be shorter than requested:" ] }, { "cell_type": "code", "execution_count": null, "id": "d97efa25-9050-4e28-b758-da6f43c9f963", "metadata": {}, "outputs": [], "source": [ "sns.mpl_palette(\"Set2\", 10)" ] }, { "cell_type": "code", "execution_count": null, "id": "f64ad118-e213-43cc-a714-98ed13cc3824", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, 
"file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Agg.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "0d053943-66c9-410d-ad65-ce91f1c1ff48", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "51b029af-b83b-4ae0-a6ff-f48bf9692518", "metadata": {}, "source": [ "The default behavior is to aggregate by taking a mean over each group:" ] }, { "cell_type": "code", "execution_count": null, "id": "28451b4e-9f4e-4604-b2b9-6138c4f51436", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(diamonds, \"clarity\", \"carat\")\n", "p.add(so.Bar(), so.Agg())" ] }, { "cell_type": "raw", "id": "53859a3b-051c-423d-97ef-b03f647268b7", "metadata": {}, "source": [ "Other aggregation functions can be selected by name if they are pandas methods:" ] }, { "cell_type": "code", "execution_count": null, "id": "5beaac3a-b9f7-4acc-81c7-480599e3675e", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bar(), so.Agg(\"median\"))" ] }, { "cell_type": "raw", "id": "2d318ee3-56c1-4fd4-99a5-fa87db770f67", "metadata": {}, "source": [ "It's also possible to pass an arbitrary aggregation function:" ] }, { "cell_type": "code", "execution_count": null, "id": "bd11e289-7274-464a-b781-06fb756cf8de", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bar(), so.Agg(lambda x: x.quantile(.75) - x.quantile(.25)))" ] }, { "cell_type": "raw", "id": "555394c1-25f8-4932-94d1-f67a8a9fa1c6", "metadata": {}, "source": [ "When other mapping variables are assigned, they'll be used to define aggregation groups. 
With some marks, it may be helpful to use additional transforms, such as :class:`Dodge`:" ] }, { "cell_type": "code", "execution_count": null, "id": "5755cdeb-1d1a-4434-9cc5-91024735eb4e", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bar(), so.Agg(), so.Dodge(), color=\"cut\")" ] }, { "cell_type": "raw", "id": "07eb1150-db57-4a58-b830-8a7aba9f46ec", "metadata": {}, "source": [ "The variable that gets aggregated depends on the orientation of the layer, which is usually inferred from the coordinate variable types (but may also be specified with the `orient` parameter in :meth:`Plot.add`):" ] }, { "cell_type": "code", "execution_count": null, "id": "1bdcc970-1b6c-4a3d-b0bc-6c7a625163ff", "metadata": {}, "outputs": [], "source": [ "so.Plot(diamonds, \"carat\", \"clarity\").add(so.Bar(), so.Agg())" ] }, { "cell_type": "code", "execution_count": null, "id": "ad8006ff-5472-4345-9537-a5680c519f4f", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Area.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "healthexp = (\n", " load_dataset(\"healthexp\")\n", " .pivot(index=\"Year\", columns=\"Country\", values=\"Spending_USD\")\n", " .interpolate()\n", " .stack()\n", " .rename(\"Spending_USD\")\n", " .reset_index()\n", " .sort_values(\"Country\")\n", ")" ] }, { "cell_type": "code", 
"execution_count": null, "id": "6d3bc7fe-0b0b-49eb-8f8b-ddd8c7441044", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(healthexp, \"Year\", \"Spending_USD\").facet(\"Country\", wrap=3)\n", "p.add(so.Area())" ] }, { "cell_type": "raw", "id": "3a47b7f1-31ef-4218-a1ea-c289f3c64ab5", "metadata": {}, "source": [ "The `color` property sets both the edge and fill color:" ] }, { "cell_type": "code", "execution_count": null, "id": "1697359a-bf26-49d0-891b-49c207cab82d", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), color=\"Country\")" ] }, { "cell_type": "raw", "id": "9bfaed37-7153-45d9-89e5-b348c7c14401", "metadata": {}, "source": [ "It's also possible to map only the `edgecolor`:" ] }, { "cell_type": "code", "execution_count": null, "id": "39e5c9e5-793e-450c-a5d2-e09d5ad1f854", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(color=\".5\", edgewidth=2), edgecolor=\"Country\")" ] }, { "cell_type": "raw", "id": "0b1a5297-9e96-472d-b284-919048e41358", "metadata": {}, "source": [ "The mark is drawn as a polygon, but it can be combined with :class:`Line` to draw a shaded region by setting `edgewidth=0`:" ] }, { "cell_type": "code", "execution_count": null, "id": "42b65535-acf6-4634-84bd-6e35305e3018", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(edgewidth=0)).add(so.Line())" ] }, { "cell_type": "raw", "id": "59761f97-eadb-4047-9e6b-09339545fe57", "metadata": {}, "source": [ "The layer's orientation defines the axis that the mark fills from:" ] }, { "cell_type": "code", "execution_count": null, "id": "a1c30f88-6287-486d-ae4b-fc272bc8e6ab", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), x=\"Spending_USD\", y=\"Year\", orient=\"y\")" ] }, { "cell_type": "raw", "id": "f1b893c5-6847-4e5b-9fc2-4190ddd75099", "metadata": {}, "source": [ "This mark can be stacked to show part-whole relationships:" ] }, { "cell_type": "code", "execution_count": null, "id": "66a79e6e-3e7f-4f54-9394-f8b003a0e228", "metadata": {}, "outputs": [], 
"source": [ "(\n", " so.Plot(healthexp, \"Year\", \"Spending_USD\", color=\"Country\")\n", " .add(so.Area(alpha=.7), so.Stack())\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "69f4e423-94f4-4003-b337-12162d1040c2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Band.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "fmri = load_dataset(\"fmri\").query(\"region == 'parietal'\")\n", "seaice = (\n", " load_dataset(\"seaice\")\n", " .assign(\n", " Day=lambda x: x[\"Date\"].dt.day_of_year,\n", " Year=lambda x: x[\"Date\"].dt.year,\n", " )\n", " .query(\"Year >= 1980\")\n", " .astype({\"Year\": str})\n", " .pivot(index=\"Day\", columns=\"Year\", values=\"Extent\")\n", " .filter([\"1980\", \"2019\"])\n", " .dropna()\n", " .reset_index()\n", ")" ] }, { "cell_type": "raw", "id": "e840e876-fbd6-4bfd-868c-a9d7af7913fa", "metadata": {}, "source": [ "The mark fills between pairs of data points to show an interval on the value axis:" ] }, { "cell_type": "code", "execution_count": null, "id": "518cf20d-bb0b-433a-9b25-f1ed8d432149", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(seaice, x=\"Day\", ymin=\"1980\", ymax=\"2019\")\n", "p.add(so.Band())" ] }, { "cell_type": "raw", "id": "fa50b778-13f9-4368-a967-68365fd51117", "metadata": {}, "source": [ "By default it draws 
a faint ribbon with no edges, but edges can be added:" ] }, { "cell_type": "code", "execution_count": null, "id": "a05176c4-0615-49ca-a2df-48ced8b5a8a8", "metadata": {}, "outputs": [], "source": [ "p.add(so.Band(alpha=.5, edgewidth=2))" ] }, { "cell_type": "raw", "id": "776d192a-f35f-4253-be7f-01e4b2466dad", "metadata": {}, "source": [ "The defaults are optimized for the main expected use case, where the mark is combined with a line to show an errorbar interval:" ] }, { "cell_type": "code", "execution_count": null, "id": "69f4e423-94f4-4003-b337-12162d1040c2", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", " .add(so.Band(), so.Est())\n", " .add(so.Line(), so.Agg())\n", ")" ] }, { "cell_type": "raw", "id": "9f0c82bf-3457-4ac5-ba48-8930bac03d75", "metadata": {}, "source": [ "When min/max values are not explicitly assigned or added in a transform, the band will cover the full extent of the data:" ] }, { "cell_type": "code", "execution_count": null, "id": "309f578e-da3d-4dc5-b6ac-a354321334c8", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(fmri, x=\"timepoint\", y=\"signal\", color=\"event\")\n", " .add(so.Line(linewidth=.5), group=\"subject\")\n", " .add(so.Band())\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "4330a3cd-63fe-470a-8e83-09e9606643b5", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Bar.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": 
"2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")\n", "flights = load_dataset(\"flights\").query(\"year == 1960\")" ] }, { "cell_type": "raw", "id": "4e817cdd-09a3-4cf6-8602-e9665607bfe1", "metadata": {}, "source": [ "The mark draws discrete bars from a baseline to provided values:" ] }, { "cell_type": "code", "execution_count": null, "id": "5a4e5ba1-50ce-4060-8eb7-f17fee9080c0", "metadata": {}, "outputs": [], "source": [ "so.Plot(flights[\"month\"], flights[\"passengers\"]).add(so.Bar())" ] }, { "cell_type": "raw", "id": "252cf7b2-7fc8-4085-8174-0126743d8a08", "metadata": {}, "source": [ "The bars are oriented depending on the x/y variable types and the `orient` parameter:" ] }, { "cell_type": "code", "execution_count": null, "id": "81dbbc81-178a-46dd-9acf-2c57d2a7e315", "metadata": {}, "outputs": [], "source": [ "so.Plot(flights[\"passengers\"], flights[\"month\"]).add(so.Bar())" ] }, { "cell_type": "markdown", "id": "6fddeceb-25b9-4fc1-bae0-4cc4cb612674", "metadata": {}, "source": [ "A common use case will be drawing histograms on a variable with a nominal scale:" ] }, { "cell_type": "code", "execution_count": null, "id": "08604543-c681-4cd3-943e-b57c0f863b2e", "metadata": {}, "outputs": [], "source": [ "so.Plot(penguins, x=\"species\").add(so.Bar(), so.Hist())" ] }, { "cell_type": "markdown", "id": "8b9af978-fdb0-46aa-9cf9-d3e49e38b344", "metadata": {}, "source": [ "When mapping additional variables, the bars will overlap by default:" ] }, { "cell_type": "code", "execution_count": null, "id": "297f7fef-7c31-40dd-ac68-e0ce7f131528", "metadata": {}, "outputs": [], "source": [ "so.Plot(penguins, x=\"species\", color=\"sex\").add(so.Bar(), so.Hist())" ] }, { "cell_type": "raw", "id": "cd9b7b4a-3150-42b5-b1a8-1c5950ca8703", "metadata": {}, "source": [ "Apply a move transform, such as a :class:`Dodge` or 
:class:`Stack` to resolve them:" ] }, { "cell_type": "code", "execution_count": null, "id": "a13c7594-737c-4215-b2a2-e59fc2d033c3", "metadata": {}, "outputs": [], "source": [ "so.Plot(penguins, x=\"species\", color=\"sex\").add(so.Bar(), so.Hist(), so.Dodge())" ] }, { "cell_type": "raw", "id": "f5f44a6b-610a-4523-a7c2-39c804a60520", "metadata": {}, "source": [ "A number of properties can be mapped or set:" ] }, { "cell_type": "code", "execution_count": null, "id": "e5cbf5a9-effb-4550-bdaf-c266dc69d3f0", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(\n", " penguins, x=\"species\",\n", " color=\"sex\", alpha=\"sex\", edgestyle=\"sex\",\n", " )\n", " .add(so.Bar(edgewidth=2), so.Hist(), so.Dodge(\"fill\"))\n", ")" ] }, { "cell_type": "raw", "id": "539144d9-75bc-4eb0-8fed-ca57b516b6d3", "metadata": {}, "source": [ "Combine with :class:`Range` to plot an estimate with errorbars:" ] }, { "cell_type": "code", "execution_count": null, "id": "89233c4a-38e7-4807-b3b4-3b4540ffcf56", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, \"body_mass_g\", \"species\", color=\"sex\")\n", " .add(so.Bar(alpha=.5), so.Agg(), so.Dodge())\n", " .add(so.Range(), so.Est(errorbar=\"sd\"), so.Dodge())\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "4f6a97a0-2d92-4fd5-ad98-b4299bda1b6b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Bars.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", 
"metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "5cf83822-ceb1-4ce5-8364-069466f7aa40", "metadata": {}, "source": [ "This mark draws bars between a baseline and a value. In contrast to :class:`Bar`, the bars have a full width and thin edges by default; this makes this mark a better choice for a continuous histogram:" ] }, { "cell_type": "code", "execution_count": null, "id": "e9b99eaf-695f-41ae-9bd1-bfe406dedb63", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(diamonds, \"price\").scale(x=\"log\")\n", "p.add(so.Bars(), so.Hist())" ] }, { "cell_type": "raw", "id": "bc4c0f25-3f7a-4a2c-a032-151da47f5ea3", "metadata": {}, "source": [ "When mapping the color or other properties, bars will overlap by default; this is usually confusing:" ] }, { "cell_type": "code", "execution_count": null, "id": "7989211b-7a29-4763-bb97-4ea19cdef081", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(), color=\"cut\")" ] }, { "cell_type": "raw", "id": "f16a3b5d-1ac1-4d9d-9bc6-d4cea7f83a17", "metadata": {}, "source": [ "Using a move transform, such as :class:`Stack` or :class:`Dodge`, will resolve the overlap (although faceting might often be a better approach):" ] }, { "cell_type": "code", "execution_count": null, "id": "8933f5f7-1423-4741-b7be-6239ea8b2fee", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(), so.Stack(), color=\"cut\")" ] }, { "cell_type": "raw", "id": "74075e80-0361-4388-a459-cbfa6418df6c", "metadata": {}, "source": [ "A number of different properties can be set or mapped:" ] }, { "cell_type": "code", "execution_count": null, "id": "04fada68-a61b-451c-b3bd-9aaab16b5f29", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(edgewidth=0), so.Hist(), so.Stack(), alpha=\"clarity\")" ] }, { "cell_type": "raw", "id": "a14d7d36-9d8b-4024-8653-002e9da946d7", 
"metadata": {}, "source": [ "It is possible to draw unfilled bars, but you must override the default edge color:" ] }, { "cell_type": "code", "execution_count": null, "id": "21642f8c-99c7-4f61-b3f5-bc1dacc638c3", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(fill=False, edgecolor=\"C0\", edgewidth=1.5), so.Hist())" ] }, { "cell_type": "raw", "id": "dce5b6cc-0808-48ec-b4d6-0c0c2e5178d2", "metadata": {}, "source": [ "It is also possible to narrow the bars, which may be useful for dealing with overlap in some cases:" ] }, { "cell_type": "code", "execution_count": null, "id": "166693bf-420c-4ec3-8da2-abc22724952b", "metadata": {}, "outputs": [], "source": [ "hist = so.Hist(binwidth=.075, binrange=(2, 5))\n", "(\n", " p.add(so.Bars(), hist)\n", " .add(\n", " so.Bars(color=\".9\", width=.5), hist,\n", " data=diamonds.query(\"cut == 'Ideal'\")\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "b40b02c4-fb2c-4300-93e4-24ea28bc6ef8", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Count.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "89113d6b-70b9-4ebe-9910-10a80eab246e", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "tips = load_dataset(\"tips\")" ] }, { "cell_type": "raw", "id": "daf6ff78-df24-4541-ba72-73fb9eddb50d", "metadata": {}, "source": [ "The transform counts distinct observations of the orientation variable defines a new 
variable on the opposite axis:" ] }, { "cell_type": "code", "execution_count": null, "id": "390f2fd3-0596-40e3-b262-163b3a90d055", "metadata": {}, "outputs": [], "source": [ "so.Plot(tips, x=\"day\").add(so.Bar(), so.Count())" ] }, { "cell_type": "raw", "id": "813fb4a5-db68-4b51-b236-5b5628ebba47", "metadata": {}, "source": [ "When additional mapping variables are defined, they are also used to define groups:" ] }, { "cell_type": "code", "execution_count": null, "id": "76a4ae70-e914-4f54-b979-ce1b79374fc3", "metadata": {}, "outputs": [], "source": [ "so.Plot(tips, x=\"day\", color=\"sex\").add(so.Bar(), so.Count(), so.Dodge())" ] }, { "cell_type": "raw", "id": "2973dee1-5aee-4768-846d-22d220faf170", "metadata": {}, "source": [ "Unlike :class:`Hist`, numeric data are not binned before counting:" ] }, { "cell_type": "code", "execution_count": null, "id": "6f94c5f0-680e-4d8a-a1c9-70876980dd1c", "metadata": {}, "outputs": [], "source": [ "so.Plot(tips, x=\"size\").add(so.Bar(), so.Count())" ] }, { "cell_type": "raw", "id": "11acd5e6-f477-4eb1-b1d7-72f4582bca45", "metadata": {}, "source": [ "When the `y` variable is defined, the counts are assigned to the `x` variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "924e0e35-210f-4f65-83b4-4aebe41ad264", "metadata": {}, "outputs": [], "source": [ "so.Plot(tips, y=\"size\").add(so.Bar(), so.Count())" ] }, { "cell_type": "code", "execution_count": null, "id": "0229fa39-b6dc-48da-9a25-31e25ed34ebc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Dash.ipynb 
================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "3227e585-7166-44e7-b0c2-8570e098102d", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "1b424322-eaa4-45c7-8007-a671ef2afbde", "metadata": {}, "source": [ "A line segment is drawn for each datapoint, centered on the value along the orientation axis:" ] }, { "cell_type": "code", "execution_count": null, "id": "fc835356-2dc2-4583-a9f9-c1fe0a6cc9ea", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(penguins, \"species\", \"body_mass_g\", color=\"sex\")\n", "p.add(so.Dash())" ] }, { "cell_type": "raw", "id": "ad9b94de-f19f-4e60-8275-686e749da39c", "metadata": {}, "source": [ "A number of properties can be mapped or set directly:" ] }, { "cell_type": "code", "execution_count": null, "id": "6070a665-ab19-43a6-9eba-e206193d9422", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dash(alpha=.5), linewidth=\"flipper_length_mm\")" ] }, { "cell_type": "raw", "id": "2c4a8291-0a84-4e70-a992-756850933791", "metadata": {}, "source": [ "The mark has a `width` property, which is relative to the spacing between orientation values:" ] }, { "cell_type": "code", "execution_count": null, "id": "315327da-421e-46c8-8a1b-8b87355d0439", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dash(width=.5))" ] }, { "cell_type": "raw", "id": "224bf51a-b8d8-4d8e-b0ab-b63ec6788584", "metadata": {}, "source": [ "When dodged, the width will automatically adapt:" ] }, { "cell_type": "code", "execution_count": null, "id": "227e889c-7ce7-49fc-b985-f7746393930e", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dash(), so.Dodge())" ] }, { "cell_type": "raw", "id": "aa807f57-5d37-4faa-8fd2-1e5378115f9f", "metadata": {}, "source": [ "This mark works well to show aggregate values when paired with a strip plot:" ] }, { 
"cell_type": "code", "execution_count": null, "id": "5141e0b8-ea1a-4178-adde-21b4bc2e705f", "metadata": {}, "outputs": [], "source": [ "(\n", " p\n", " .add(so.Dash(), so.Agg(), so.Dodge())\n", " .add(so.Dots(), so.Dodge(), so.Jitter())\n", ")" ] }, { "cell_type": "raw", "id": "f2abd4b7-5afb-4661-95f3-b51bfa101273", "metadata": {}, "source": [ "When both coordinate variables are numeric, you can control the orientation explicitly:" ] }, { "cell_type": "code", "execution_count": null, "id": "f6d7e236-327f-460f-b12e-46d7444ac348", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(\n", " penguins[\"body_mass_g\"],\n", " penguins[\"flipper_length_mm\"].round(-1),\n", " )\n", " .add(so.Dash(), orient=\"y\")\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "6811d776-93e5-49ce-88a6-14786a67841d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Dodge.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "4d44a940-db84-4e16-bc83-e67d08d6d56a", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "tips = load_dataset(\"tips\").astype({\"time\": str})" ] }, { "cell_type": "raw", "id": "ce99e1a1-c213-478f-a5bc-d19e2c4d70db", "metadata": {}, "source": [ "This transform modifies both the width and position (along the orientation axis) of marks that would otherwise overlap:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"f6a84062-2c2b-4a45-91cb-77f29462104d", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, \"day\", color=\"time\")\n", " .add(so.Bar(), so.Count(), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "55d3a9a8-c973-4e91-9f3a-bc137df15f48", "metadata": {}, "source": [ "By default, empty space may appear when variables are not fully crossed:" ] }, { "cell_type": "code", "execution_count": null, "id": "08ae1c65-5ad9-47a3-a8f3-d901bd4821f2", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(tips, \"day\", color=\"time\")\n", "p.add(so.Bar(), so.Count(), so.Dodge())" ] }, { "cell_type": "raw", "id": "2125f07d-4210-4d49-8761-bcfa3f9c67f5", "metadata": {}, "source": [ "The `empty` parameter handles this case; use it to fill out the space:" ] }, { "cell_type": "code", "execution_count": null, "id": "c2314343-de73-45d7-9595-acf5f7d62e93", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bar(), so.Count(), so.Dodge(empty=\"fill\"))" ] }, { "cell_type": "raw", "id": "08f4382c-842e-4777-a452-1d88251da6e7", "metadata": {}, "source": [ "Or center the marks while using a consistent width:" ] }, { "cell_type": "code", "execution_count": null, "id": "1e0745e4-be11-4703-bf9c-4b13cbb76e91", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bar(), so.Count(), so.Dodge(empty=\"drop\"))" ] }, { "cell_type": "raw", "id": "7d29ec53-caef-4cff-9828-dc242adb5c49", "metadata": {}, "source": [ "Use `gap` to add a bit of spacing between dodged marks:" ] }, { "cell_type": "code", "execution_count": null, "id": "342aca16-c67b-4bc4-9101-fec6c398aa0f", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(tips, \"day\", \"total_bill\", color=\"sex\")\n", "p.add(so.Bar(), so.Agg(\"sum\"), so.Dodge(gap=.1))" ] }, { "cell_type": "raw", "id": "68b52dcb-c5e7-4186-b61f-e96fac5f4d40", "metadata": {}, "source": [ "When multiple semantic variables are used, each distinct group will be dodged:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"497f3e3b-39bc-4381-85bb-be5bb5c60b1f", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dot(), so.Dodge(), fill=\"smoker\")" ] }, { "cell_type": "raw", "id": "795835d2-904f-4343-89c2-b91be9c1c504", "metadata": {}, "source": [ "Use `by` to dodge only a subset of variables:" ] }, { "cell_type": "code", "execution_count": null, "id": "da01f6c0-c425-409c-a010-5cb52a794dc9", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dot(), so.Dodge(by=[\"color\"]), fill=\"smoker\")" ] }, { "cell_type": "raw", "id": "77de77da-2fad-4374-9d14-90520e448c90", "metadata": {}, "source": [ "When combining with other transforms (such as :class:`Jitter` or :class:`Stack`), be mindful of the order that they are applied in:" ] }, { "cell_type": "code", "execution_count": null, "id": "29ccabd6-6bd5-4563-a337-f8f8d25f7dad", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dot(), so.Dodge(), so.Jitter())" ] }, { "cell_type": "code", "execution_count": null, "id": "a73fe9a5-c717-41fd-874e-be72334ea6d4", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Dot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "tips = load_dataset(\"tips\")\n", "glue = load_dataset(\"glue\")" ] }, { "cell_type": "raw", "id": "f8e7b343-0301-49b3-8d42-862266d322bb", "metadata": {}, "source": [ "This mark draws relatively large, 
filled dots by default:" ] }, { "cell_type": "code", "execution_count": null, "id": "f92e97d0-b6a5-41ec-8507-dc64e60cb6e0", "metadata": {}, "outputs": [], "source": [ "p1 = so.Plot(tips, \"total_bill\", \"tip\")\n", "p1.add(so.Dot())" ] }, { "cell_type": "raw", "id": "625abe2a-7b0b-42a7-bfbc-dc2bfaf14897", "metadata": {}, "source": [ "While :class:`Dots` is a better choice for dense scatter plots, adding a thin edge can help to resolve individual points:" ] }, { "cell_type": "code", "execution_count": null, "id": "a3c7c22d-c7ce-40a9-941b-a8bc30db1e54", "metadata": {}, "outputs": [], "source": [ "p1.add(so.Dot(edgecolor=\"w\"))" ] }, { "cell_type": "markdown", "id": "398a43e1-4d45-42ea-bc87-41a8602540a4", "metadata": {}, "source": [ "Dodging and jittering can also help to reduce overplotting, when appropriate:" ] }, { "cell_type": "code", "execution_count": null, "id": "1b15e393-35cf-457f-8180-d92d05e2675a", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, \"total_bill\", \"day\", color=\"sex\")\n", " .add(so.Dot(), so.Dodge(), so.Jitter(.2))\n", ")" ] }, { "cell_type": "raw", "id": "12453ada-40e6-4aad-9f32-ba41fd7b27ca", "metadata": {}, "source": [ "The larger dot size makes this mark well suited to representing values along a nominal scale:" ] }, { "cell_type": "code", "execution_count": null, "id": "bd2edac0-ee6b-4cc9-8201-641b589630b8", "metadata": {}, "outputs": [], "source": [ "p2 = so.Plot(glue, \"Score\", \"Model\").facet(\"Task\", wrap=4).limit(x=(-5, 105))\n", "p2.add(so.Dot())" ] }, { "cell_type": "raw", "id": "ddd86209-d5cd-4f7a-9274-c578bc6a9f07", "metadata": {}, "source": [ "A number of properties can be set or mapped:" ] }, { "cell_type": "code", "execution_count": null, "id": "d00cdc35-4b9c-4f32-a047-8e036e565c4f", "metadata": {}, "outputs": [], "source": [ "(\n", " p2\n", " .add(so.Dot(pointsize=6), color=\"Year\", marker=\"Encoder\")\n", " .scale(marker=[\"o\", \"s\"], color=\"flare\")\n", ")" ] }, { "cell_type": "raw", "id": 
"061e22f4-8505-425d-8c80-8ac82c6a3125", "metadata": {}, "source": [ "Note that the edge properties are parameterized differently for filled and unfilled markers; use `stroke` and `color` rather than `edgewidth` and `edgecolor` if the marker is unfilled:" ] }, { "cell_type": "code", "execution_count": null, "id": "964b00be-1c29-4664-838d-0daeead9154a", "metadata": {}, "outputs": [], "source": [ "p2.add(so.Dot(stroke=1.5), fill=\"Encoder\", color=\"Encoder\")" ] }, { "cell_type": "raw", "id": "fb5e1383-1460-4389-a67b-09ec7965af90", "metadata": {}, "source": [ "Combine with :class:`Range` to show error bars:" ] }, { "cell_type": "code", "execution_count": null, "id": "b2618c22-bc7f-4ddd-9824-346e8d9b2b51", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"day\")\n", " .add(so.Dot(pointsize=3), so.Shift(y=.2), so.Jitter(.2))\n", " .add(so.Dot(), so.Agg())\n", " .add(so.Range(), so.Est(errorbar=(\"se\", 2)))\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "e5dc04fd-dba4-4b86-99a1-31ba00c7650d", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Dots.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "mpg = load_dataset(\"mpg\")" ] }, { "cell_type": "raw", "id": "f8e7b343-0301-49b3-8d42-862266d322bb", "metadata": {}, "source": [ "This 
mark draws relatively small, partially-transparent dots:" ] }, { "cell_type": "code", "execution_count": null, "id": "d668d7f6-555b-4b5d-876e-35e259076d2a", "metadata": {}, "outputs": [], "source": [ "p1 = so.Plot(mpg, \"horsepower\", \"mpg\")\n", "p1.add(so.Dots())" ] }, { "cell_type": "raw", "id": "a2cf4669-9c91-4adc-9e3a-3b0660e7898e", "metadata": {}, "source": [ "Fixing or mapping the `color` property changes both the stroke (edge) and fill:" ] }, { "cell_type": "code", "execution_count": null, "id": "bba2b1c5-22fd-4f44-af8d-defb31dfbe9d", "metadata": {}, "outputs": [], "source": [ "p1.add(so.Dots(), color=\"origin\")" ] }, { "cell_type": "raw", "id": "bf967d57-22cf-4bce-b718-aae6936719e6", "metadata": {}, "source": [ "These properties can be independently parametrized (although the resulting plot may not always be clear):" ] }, { "cell_type": "code", "execution_count": null, "id": "c45261a9-fb88-4eb5-b633-060debda261b", "metadata": {}, "outputs": [], "source": [ "(\n", " p1.add(so.Dots(fillalpha=.5), color=\"origin\", fillcolor=\"weight\")\n", " .scale(fillcolor=\"binary\")\n", ")" ] }, { "cell_type": "raw", "id": "b20dcaee-8e09-4a76-8eff-5289ef43ea8c", "metadata": {}, "source": [ "Filled and unfilled markers will happily mix:" ] }, { "cell_type": "code", "execution_count": null, "id": "a1a9bdda-abb7-4850-a936-ceed518b9b17", "metadata": {}, "outputs": [], "source": [ "p1.add(so.Dots(stroke=1), marker=\"origin\").scale(marker=[\"o\", \"x\", (6, 2, 1)])" ] }, { "cell_type": "raw", "id": "1d932f10-e8f8-4114-9362-3da82c7b5ac0", "metadata": {}, "source": [ "The partial opacity also helps to see local density when using jitter:" ] }, { "cell_type": "code", "execution_count": null, "id": "692e1611-4804-4979-b616-041e9fa9cdd9", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg, \"horsepower\", \"origin\")\n", " .add(so.Dots(), so.Jitter(.25))\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "acd5788f-e62b-497c-a109-f0bc02b8cae9", 
"metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Est.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "57ececfa-0ae0-4acb-b85d-7c6a6ca8d3db", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "03c64256-8daf-4b32-87bd-b425e27a7823", "metadata": {}, "source": [ "The default behavior is to compute the mean and 95% confidence interval (using bootstrapping):" ] }, { "cell_type": "code", "execution_count": null, "id": "46017dc7-7c3c-4dcf-9232-2e3ac490d980", "metadata": { "tags": [] }, "outputs": [], "source": [ "p = so.Plot(diamonds, \"clarity\", \"carat\")\n", "p.add(so.Range(), so.Est())" ] }, { "cell_type": "raw", "id": "1bf04e8d-998e-4a47-9375-ddcde76e3914", "metadata": {}, "source": [ "Other estimators may be selected by name if they are pandas methods:" ] }, { "cell_type": "code", "execution_count": null, "id": "ea394c55-8fa6-4fb0-8665-42c03ef3576e", "metadata": {}, "outputs": [], "source": [ "p.add(so.Range(), so.Est(\"median\"))" ] }, { "cell_type": "raw", "id": "9c5f3c91-fecb-4e75-b045-b30870154083", "metadata": {}, "source": [ "There are several options for computing the error bar interval, such as (scaled) standard errors:" ] }, { "cell_type": "code", "execution_count": null, "id": "9c350af5-d549-4cce-b3f2-e9bef33aef36", "metadata": {}, "outputs": [], "source": [ "p.add(so.Range(), 
so.Est(errorbar=\"se\"))" ] }, { "cell_type": "raw", "id": "8c8d321b-5e73-418c-8c71-4b91cf187e57", "metadata": {}, "source": [ "The error bars can also represent the spread of the distribution around the estimate using (scaled) standard deviations:" ] }, { "cell_type": "code", "execution_count": null, "id": "fd2cd9dc-e4c9-4ba1-ac79-38806cf1e009", "metadata": {}, "outputs": [], "source": [ "p.add(so.Range(), so.Est(errorbar=\"sd\"))" ] }, { "cell_type": "raw", "id": "6dba074b-881c-40df-b42e-458e4a26e23d", "metadata": {}, "source": [ "Because confidence intervals are computed using bootstrapping, there will be small amounts of randomness. Reduce the random variability by increasing the number of bootstrap iterations (although this will be slower), or eliminate it by seeding the random number generator:" ] }, { "cell_type": "code", "execution_count": null, "id": "d6b450e1-8b1f-411f-aa01-bbb46ab3b6ec", "metadata": {}, "outputs": [], "source": [ "p.add(so.Range(), so.Est(seed=0))" ] }, { "cell_type": "markdown", "id": "df807ef8-b5fb-4eac-b539-1bd4e797ddc2", "metadata": {}, "source": [ "To compute a weighted estimate (and confidence interval), assign a `weight` variable in the layer where you use the stat:" ] }, { "cell_type": "code", "execution_count": null, "id": "5e4a0594-e1ee-4f72-971e-3763dd626e8b", "metadata": {}, "outputs": [], "source": [ "p.add(so.Range(), so.Est(), weight=\"price\")" ] }, { "cell_type": "code", "execution_count": null, "id": "0d0c34d7-fb76-44cf-9079-3ec7f45741d0", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE:
doc/_docstrings/objects.Hist.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "59690096-a0ad-4ff3-b82c-0258d724035a", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "c345a35c-bac8-4163-ba40-e7c208df1033", "metadata": {}, "source": [ "For discrete or categorical variables, this stat is commonly combined with a :class:`Bar` mark:" ] }, { "cell_type": "code", "execution_count": null, "id": "6a96ac9b-1240-496d-9385-840205945208", "metadata": {}, "outputs": [], "source": [ "so.Plot(penguins, \"island\").add(so.Bar(), so.Hist())" ] }, { "cell_type": "raw", "id": "1e5ff9d5-c6a9-4adc-a9be-0f155b1575be", "metadata": {}, "source": [ "When used to estimate a univariate distribution, it is better to use the :class:`Bars` mark:" ] }, { "cell_type": "code", "execution_count": null, "id": "7f3e3144-752a-4d71-9528-85eb1ed0a9a4", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(penguins, \"flipper_length_mm\")\n", "p.add(so.Bars(), so.Hist())" ] }, { "cell_type": "raw", "id": "008b9ffe-da74-4406-9756-4f70e333f33b", "metadata": {}, "source": [ "The granularity of the bins will influence whether the underlying distribution is accurately represented. 
Adjust it by setting the total number:" ] }, { "cell_type": "code", "execution_count": null, "id": "27d221d5-add5-40a8-85d2-05102384dad1", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(bins=20))" ] }, { "cell_type": "raw", "id": "fffebb54-0299-45c5-b7fb-6fcad6427239", "metadata": {}, "source": [ "Alternatively, specify the *width* of the bins:" ] }, { "cell_type": "code", "execution_count": null, "id": "d036ca65-7dcf-45ac-a2d1-caafb9f922a7", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(binwidth=5))" ] }, { "cell_type": "raw", "id": "bc1e4bd3-2a16-42bd-9c13-a660dd381f66", "metadata": {}, "source": [ "By default, the transform returns the count of observations in each bin. The counts can be normalized, e.g. to show a proportion:" ] }, { "cell_type": "code", "execution_count": null, "id": "dbf23712-2231-4226-8265-0e2a5299c4bb", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(stat=\"proportion\"))" ] }, { "cell_type": "raw", "id": "6c6fb23e-78c5-4630-a958-62cb4dee4ec8", "metadata": {}, "source": [ "When additional variables define groups, the default behavior is to normalize across all groups:" ] }, { "cell_type": "code", "execution_count": null, "id": "ac3fe4ef-56e3-4ec7-b580-596d2a3d924b", "metadata": {}, "outputs": [], "source": [ "p = p.facet(\"island\")\n", "p.add(so.Bars(), so.Hist(stat=\"proportion\"))" ] }, { "cell_type": "raw", "id": "f7afc403-26cc-4325-a28a-913c2291aa35", "metadata": {}, "source": [ "Pass `common_norm=False` to normalize each distribution independently:" ] }, { "cell_type": "code", "execution_count": null, "id": "b2029324-069f-4261-a178-1efad2fd0e88", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(stat=\"proportion\", common_norm=False))" ] }, { "cell_type": "raw", "id": "0f83401a-e456-4a14-af69-f1483c6c03c4", "metadata": {}, "source": [ "Or, with more than one grouping variable, specify a subset to normalize within:" ] }, { "cell_type": "code",
"execution_count": null, "id": "5c092262-8a8f-4a3e-8cae-9e0f23dd94ba", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(stat=\"proportion\", common_norm=[\"col\"]), color=\"sex\")" ] }, { "cell_type": "raw", "id": "86532133-bf33-4674-9614-86ae3408aa51", "metadata": {}, "source": [ "When distributions overlap it may be easier to discern their shapes with an :class:`Area` mark:" ] }, { "cell_type": "code", "execution_count": null, "id": "00b18ad8-52d4-460a-a012-d87c66b3e71e", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), so.Hist(), color=\"sex\")" ] }, { "cell_type": "raw", "id": "2b34d435-abbf-41aa-b219-91883d7d29f3", "metadata": {}, "source": [ "Or add :class:`Stack` move to represent a part-whole relationship:" ] }, { "cell_type": "code", "execution_count": null, "id": "3a7a0c05-d774-4f99-950f-5dc9865027c4", "metadata": {}, "outputs": [], "source": [ "p.add(so.Bars(), so.Hist(), so.Stack(), color=\"sex\")" ] }, { "cell_type": "code", "execution_count": null, "id": "e247e74b-2c09-40f0-8f45-9fa5f8264d78", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Jitter.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "f2e5a85d-c710-492b-a4fc-09b45ae26471", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "14b5927c-42f1-4934-adee-3d380b8b3228", "metadata": {}, "source": [ 
"When used without any arguments, a small amount of jitter will be applied along the orientation axis:" ] }, { "cell_type": "code", "execution_count": null, "id": "bc1b4941-bbe6-4afc-b51a-0ac67cbe417d", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, \"species\", \"body_mass_g\")\n", " .add(so.Dots(), so.Jitter())\n", ")" ] }, { "cell_type": "raw", "id": "1101690e-6c19-4219-aa4e-180798454df1", "metadata": {}, "source": [ "The `width` parameter controls the amount of jitter relative to the spacing between the marks:" ] }, { "cell_type": "code", "execution_count": null, "id": "c4251b9d-8b11-4c2c-905c-2f3b523dee70", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, \"species\", \"body_mass_g\")\n", " .add(so.Dots(), so.Jitter(.5))\n", ")" ] }, { "cell_type": "raw", "id": "38aa639a-356e-4674-970b-53d55379b2b7", "metadata": {}, "source": [ "The `width` parameter always applies to the orientation axis, so the direction of jitter will adapt along with the orientation:" ] }, { "cell_type": "code", "execution_count": null, "id": "1cfe1c07-7e81-45a0-a989-240503046133", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, \"body_mass_g\", \"species\")\n", " .add(so.Dots(), so.Jitter(.5))\n", ")" ] }, { "cell_type": "raw", "id": "0f5de4cc-3383-4503-8b59-9c48230a12a5", "metadata": {}, "source": [ "Because the `width` jitter is relative, it can be used when the orientation axis is numeric without further tweaking:" ] }, { "cell_type": "code", "execution_count": null, "id": "c94c41e8-29c4-4439-a5d1-0b8ffb244890", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins[\"body_mass_g\"].round(-3), penguins[\"flipper_length_mm\"])\n", " .add(so.Dots(), so.Jitter())\n", ")" ] }, { "cell_type": "raw", "id": "dd982dfa-fd9f-4edc-8190-18f0e101ae1a", "metadata": {}, "source": [ "In contrast to `width`, the `x` and `y` parameters always refer to specific axes and control the jitter in data units:" ] }, { "cell_type": 
"code", "execution_count": null, "id": "b0f2e5ca-68ad-4439-a4ee-f32f65682e95", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins[\"body_mass_g\"].round(-3), penguins[\"flipper_length_mm\"])\n", " .add(so.Dots(), so.Jitter(x=100))\n", ")" ] }, { "cell_type": "raw", "id": "a90ba526-8043-42ed-8f57-36445c163c0d", "metadata": {}, "source": [ "Both `x` and `y` can be used in a single transform:" ] }, { "cell_type": "code", "execution_count": null, "id": "6c07ed1d-ac77-4b30-90a8-e1b8760f9fad", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(\n", " penguins[\"body_mass_g\"].round(-3),\n", " penguins[\"flipper_length_mm\"].round(-1),\n", " )\n", " .add(so.Dots(), so.Jitter(x=200, y=5))\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "bb04c7a2-93f0-44cf-aacf-0eb436d0f14b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.KDE.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "dcc1ae12-bba4-4de9-af8d-543b3d65b42b", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "1042b991-1471-43bd-934c-43caae3cb2fa", "metadata": {}, "source": [ "This stat transforms observations into a smooth function representing the estimated density:" ] }, { "cell_type": "code", "execution_count": null, "id": "2406e2aa-7f0f-4a51-af59-4cef827d28d8", "metadata": {}, "outputs": [],
"source": [ "p = so.Plot(penguins, x=\"flipper_length_mm\")\n", "p.add(so.Area(), so.KDE())" ] }, { "cell_type": "raw", "id": "44515f21-683b-420f-967b-4c7568c907d7", "metadata": {}, "source": [ "Adjust the smoothing bandwidth to see more or fewer details:" ] }, { "cell_type": "code", "execution_count": null, "id": "d4e6ba5b-4dd2-4210-8cf0-de057dc71e2a", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), so.KDE(bw_adjust=0.25))" ] }, { "cell_type": "raw", "id": "fd665fe1-a5e4-4742-adc9-e40615d57d08", "metadata": {}, "source": [ "The curve will extend beyond observed values in the dataset:" ] }, { "cell_type": "code", "execution_count": null, "id": "4cda1cb8-f663-4f94-aa24-6f1727a41031", "metadata": {}, "outputs": [], "source": [ "p2 = p.add(so.Bars(alpha=.3), so.Hist(\"density\"))\n", "p2.add(so.Line(), so.KDE())" ] }, { "cell_type": "raw", "id": "75235825-d522-4562-aacc-9b7413eabf5d", "metadata": {}, "source": [ "Control the range of the density curve relative to the observations using `cut`:" ] }, { "cell_type": "code", "execution_count": null, "id": "a7a9275e-9889-437d-bdc5-18653d2c92ef", "metadata": {}, "outputs": [], "source": [ "p2.add(so.Line(), so.KDE(cut=0))" ] }, { "cell_type": "raw", "id": "6a885eeb-81ba-47c6-8402-1bef40544fd1", "metadata": {}, "source": [ "When observations are assigned to the `y` variable, the density will be shown for `x`:" ] }, { "cell_type": "code", "execution_count": null, "id": "38b3a0fb-54ff-493a-bd64-f83a12365723", "metadata": {}, "outputs": [], "source": [ "so.Plot(penguins, y=\"flipper_length_mm\").add(so.Area(), so.KDE())" ] }, { "cell_type": "raw", "id": "59996340-168e-479f-a0c6-c7e1fcab0fb0", "metadata": {}, "source": [ "Use `gridsize` to increase or decrease the resolution of the grid where the density is evaluated:" ] }, { "cell_type": "code", "execution_count": null, "id": "23715820-7df9-40ba-9e74-f11564704dd0", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dots(), so.KDE(gridsize=100))" ] }, { 
"cell_type": "raw", "id": "4c9b6492-98c8-45ab-9f53-681cde2f767a", "metadata": {}, "source": [ "Or pass `None` to evaluate the density at the original datapoints:" ] }, { "cell_type": "code", "execution_count": null, "id": "4e1b6810-5c28-43aa-aa61-652521299b51", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dots(), so.KDE(gridsize=None))" ] }, { "cell_type": "raw", "id": "0970a56b-0cba-4c40-bb1b-b8e71739df5c", "metadata": {}, "source": [ "Other variables will define groups for the estimation:" ] }, { "cell_type": "code", "execution_count": null, "id": "5f0ce0b6-5742-4bc0-9ac3-abedde923684", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), so.KDE(), color=\"species\")" ] }, { "cell_type": "raw", "id": "22204fcd-4b25-46e5-a170-02b1419c23d5", "metadata": {}, "source": [ "By default, the density is normalized across all groups (i.e., the joint density is shown); pass `common_norm=False` to show conditional densities:" ] }, { "cell_type": "code", "execution_count": null, "id": "6ad56958-dc45-4632-94d1-23039ad3ec58", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), so.KDE(common_norm=False), color=\"species\")" ] }, { "cell_type": "raw", "id": "b1627197-85d1-4476-b4ae-3e93044ee988", "metadata": {}, "source": [ "Or pass a list of variables to condition on:" ] }, { "cell_type": "code", "execution_count": null, "id": "58f63734-5afd-4d90-bbfb-fc39c8d1981f", "metadata": {}, "outputs": [], "source": [ "(\n", " p.facet(\"sex\")\n", " .add(so.Area(), so.KDE(common_norm=[\"col\"]), color=\"species\")\n", ")" ] }, { "cell_type": "raw", "id": "2b7e018e-1374-4939-909c-e95f5ffd086e", "metadata": {}, "source": [ "This stat can be combined with other transforms, such as :class:`Stack` (when `common_grid=True`):" ] }, { "cell_type": "code", "execution_count": null, "id": "96e5b2d0-c7e2-47df-91f1-7f9ec0bb08a9", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), so.KDE(), so.Stack(), color=\"sex\")" ] }, { "cell_type": "raw", "id": 
"8500ff86-0b1f-4831-954b-08b6df690387", "metadata": {}, "source": [ "Set `cumulative=True` to integrate the density:" ] }, { "cell_type": "code", "execution_count": null, "id": "26bb736e-7cfd-421e-b80d-42fa450e88c0", "metadata": {}, "outputs": [], "source": [ "p.add(so.Line(), so.KDE(cumulative=True))" ] }, { "cell_type": "code", "execution_count": null, "id": "e8bfd9d2-ad60-4971-aa7f-71a285f44a20", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Line.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "dowjones = load_dataset(\"dowjones\")\n", "fmri = load_dataset(\"fmri\")" ] }, { "cell_type": "markdown", "id": "05468ecf-d2f5-46f0-ba43-ea13aba0ebd2", "metadata": {}, "source": [ "The mark draws a connecting line between sorted observations:" ] }, { "cell_type": "code", "execution_count": null, "id": "acd5788f-e62b-497c-a109-f0bc02b8cae9", "metadata": {}, "outputs": [], "source": [ "so.Plot(dowjones, \"Date\", \"Price\").add(so.Line())" ] }, { "cell_type": "markdown", "id": "94efb077-49a5-4214-891a-c68f89c79926", "metadata": {}, "source": [ "Change the orientation to connect observations along the opposite axis (`orient=\"y\"` is redundant here; the plot would detect that the date variable has a lower orientation priority than the price variable):" ] }, { "cell_type": "code", 
"execution_count": null, "id": "4c5db48f-1c88-4905-a5f5-2ae96ceb0f95", "metadata": {}, "outputs": [], "source": [ "so.Plot(dowjones, x=\"Price\", y=\"Date\").add(so.Line(), orient=\"y\")" ] }, { "cell_type": "raw", "id": "77bd0b1e-d9d1-4741-9821-83cec708e877", "metadata": {}, "source": [ "To replicate the same line multiple times, assign a `group` variable (but consider using :class:`Lines` here instead):" ] }, { "cell_type": "code", "execution_count": null, "id": "2c1b699c-4e42-4461-a7fb-0d664ef8fe1b", "metadata": {}, "outputs": [], "source": [ "(\n", " fmri\n", " .query(\"region == 'parietal' and event == 'stim'\")\n", " .pipe(so.Plot, \"timepoint\", \"signal\")\n", " .add(so.Line(color=\".2\", linewidth=1), group=\"subject\")\n", ")" ] }, { "cell_type": "raw", "id": "c09cc6a1-a86b-48b7-b276-e0e9125d279e", "metadata": {}, "source": [ "When mapping variables to properties like `color` or `linestyle`, stat transforms are computed within each grouping:" ] }, { "cell_type": "code", "execution_count": null, "id": "83b8c68d-a1ae-4bfb-b3dc-4a11bbe85cbc", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(fmri, \"timepoint\", \"signal\", color=\"region\", linestyle=\"event\")\n", "p.add(so.Line(), so.Agg())" ] }, { "cell_type": "raw", "id": "c9390f58-0fb1-47ba-8b86-bde4c41e6d1d", "metadata": {}, "source": [ "Combine with :class:`Band` to show an error bar:" ] }, { "cell_type": "code", "execution_count": null, "id": "b6ab0006-0f28-4992-b687-41889a424684", "metadata": {}, "outputs": [], "source": [ "(\n", " p\n", " .add(so.Line(), so.Agg())\n", " .add(so.Band(), so.Est(), group=\"event\")\n", ")" ] }, { "cell_type": "raw", "id": "e567df5c-6675-423f-bcd8-94cb3a400251", "metadata": {}, "source": [ "Add markers to indicate values where the data were sampled:" ] }, { "cell_type": "code", "execution_count": null, "id": "2541701c-1a2c-44dd-b300-6551861c8b98", "metadata": {}, "outputs": [], "source": [ "p.add(so.Line(marker=\"o\", edgecolor=\"w\"), so.Agg(), linestyle=None)" 
] }, { "cell_type": "code", "execution_count": null, "id": "a25d0379-b374-4539-82a4-00ce37245e1b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Lines.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "seaice = load_dataset(\"seaice\")" ] }, { "cell_type": "raw", "id": "09694cb8-4867-49fc-80a6-a4551e50b77e", "metadata": {}, "source": [ "Like :class:`Line`, the mark draws a connecting line between sorted observations:" ] }, { "cell_type": "code", "execution_count": null, "id": "acd5788f-e62b-497c-a109-f0bc02b8cae9", "metadata": {}, "outputs": [], "source": [ "so.Plot(seaice, \"Date\", \"Extent\").add(so.Lines())" ] }, { "cell_type": "raw", "id": "8f982f2d-1119-4842-9860-80b415fd24fe", "metadata": {}, "source": [ "Compared to :class:`Line`, this mark offers fewer settable properties, but it can have better performance when drawing a large number of lines:" ] }, { "cell_type": "code", "execution_count": null, "id": "d4411136-1787-47ca-91f4-4ecba541e575", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(\n", " x=seaice[\"Date\"].dt.day_of_year,\n", " y=seaice[\"Extent\"],\n", " color=seaice[\"Date\"].dt.year\n", " )\n", " .facet(seaice[\"Date\"].dt.year.round(-1))\n", " .add(so.Lines(linewidth=.5, color=\"#bbca\"), col=None)\n", " .add(so.Lines(linewidth=1))\n", " 
.scale(color=\"ch:rot=-.2,light=.7\")\n", " .layout(size=(8, 4))\n", " .label(title=\"{}s\".format)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "aaab3914-77d7-4d09-bdbe-f057a2fe28cf", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Norm.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "0bfee8b6-1e3e-499d-96ae-735a5c230b32", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "healthexp = load_dataset(\"healthexp\")" ] }, { "cell_type": "raw", "id": "43adf565-2843-48fe-a12a-1a65bc9fce9f", "metadata": {}, "source": [ "By default, this transform scales each group relative to its maximum value:" ] }, { "cell_type": "code", "execution_count": null, "id": "6262c89d-56cd-41b4-8276-0bf737b02f29", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(healthexp, x=\"Year\", y=\"Spending_USD\", color=\"Country\")\n", " .add(so.Lines(), so.Norm())\n", " .label(y=\"Spending relative to maximum amount\")\n", ")" ] }, { "cell_type": "raw", "id": "5941b47a-7f2f-4540-9944-c6a16e7eec75", "metadata": {}, "source": [ "Use `where` to constrain the values used to define a baseline, and `percent` to scale the output:" ] }, { "cell_type": "code", "execution_count": null, "id": "8142d0b4-1b91-4ba9-bc60-3df148130ff9", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(healthexp, x=\"Year\", y=\"Spending_USD\", color=\"Country\")\n", " .add(so.Lines(), 
so.Norm(where=\"x == x.min()\", percent=True))\n", " .label(y=\"Percent change in spending from 1970 baseline\")\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "2f2d2d33-8a92-44fb-b37a-24dee23a7d75", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Path.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "healthexp = load_dataset(\"healthexp\").sort_values([\"Country\", \"Year\"])" ] }, { "cell_type": "raw", "id": "8c2781ed-190d-4155-99ac-0170b94de030", "metadata": {}, "source": [ "Unlike :class:`Line`, this mark does not sort observations before plotting, making it suitable for plotting trajectories through a variable space:" ] }, { "cell_type": "code", "execution_count": null, "id": "199c0b22-1cbd-4b5a-bebe-f59afa79b9c6", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(healthexp, \"Spending_USD\", \"Life_Expectancy\", color=\"Country\")\n", "p.add(so.Path())" ] }, { "cell_type": "raw", "id": "fb87bd85-024b-42f5-b458-3550271d7124", "metadata": {}, "source": [ "It otherwise offers the same set of options, including a number of properties that can be set or mapped:" ] }, { "cell_type": "code", "execution_count": null, "id": "280de309-1c0d-4cdc-8f4c-a4f15da461cf", "metadata": {}, "outputs": [], "source": [ "p.add(so.Path(marker=\"o\", pointsize=2, linewidth=.75, 
fillcolor=\"w\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "4e795770-4481-4e23-a49b-e828a1f5cbbd", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Paths.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "networks = (\n", " load_dataset(\"brain_networks\", header=[0, 1, 2], index_col=0)\n", " .rename_axis(\"timepoint\")\n", " .stack([0, 1, 2])\n", " .groupby([\"timepoint\", \"network\", \"hemi\"])\n", " .mean()\n", " .unstack(\"network\")\n", " .reset_index()\n", " .query(\"timepoint < 100\")\n", ")" ] }, { "cell_type": "raw", "id": "50646936-5236-413f-b79b-6c3b640ade04", "metadata": {}, "source": [ "Unlike :class:`Lines`, this mark does not sort observations before plotting, making it suitable for plotting trajectories through a variable space:" ] }, { "cell_type": "code", "execution_count": null, "id": "4a3ed115-cc47-4ea8-be46-2c99f7453941", "metadata": {}, "outputs": [], "source": [ "p = (\n", " so.Plot(networks)\n", " .pair(\n", " x=[\"5\", \"8\", \"12\", \"15\"],\n", " y=[\"6\", \"13\", \"16\"],\n", " )\n", " .layout(size=(8, 5))\n", " .share(x=True, y=True)\n", ")\n", "p.add(so.Paths())" ] }, { "cell_type": "raw", "id": "5bf502eb-feb3-4b2e-882b-3e915bf5d041", "metadata": {}, "source": [ "The mark has the same set of properties as :class:`Lines`:" ] }, { 
"cell_type": "code", "execution_count": null, "id": "326a765b-59f0-46ef-91c2-6705c6893740", "metadata": {}, "outputs": [], "source": [ "p.add(so.Paths(linewidth=1, alpha=.8), color=\"hemi\")" ] }, { "cell_type": "code", "execution_count": null, "id": "175b836d-d328-4b6c-ad36-dde18c19e3bf", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Perc.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2d44a326-029b-47ff-b560-5f4b6a4bb73f", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "65e975a2-2559-4bf1-8851-8bbbf52bf22d", "metadata": {}, "source": [ "The default behavior computes the quartiles and min/max of the input data:" ] }, { "cell_type": "code", "execution_count": null, "id": "36f927f5-3b64-4871-a355-adadc4da769b", "metadata": {}, "outputs": [], "source": [ "p = (\n", " so.Plot(diamonds, \"cut\", \"price\")\n", " .scale(y=\"log\")\n", ")\n", "p.add(so.Dot(), so.Perc())" ] }, { "cell_type": "raw", "id": "feba1b99-0f71-4b18-8e7e-bd5470cc2d0c", "metadata": {}, "source": [ "Passing an integer will compute that many evenly-spaced percentiles:" ] }, { "cell_type": "code", "execution_count": null, "id": "f030dd39-1223-475a-93e1-1759a8971a6c", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dot(), so.Perc(20))" ] }, { "cell_type": "raw", "id": "85bd754b-122e-4475-8727-2d584a90a38e", "metadata": {}, 
"source": [ "Passing a list will compute exactly those percentiles:" ] }, { "cell_type": "code", "execution_count": null, "id": "2fde7549-45b5-411a-afba-eb0da754d9e9", "metadata": {}, "outputs": [], "source": [ "p.add(so.Dot(), so.Perc([10, 25, 50, 75, 90]))" ] }, { "cell_type": "raw", "id": "7be16a13-dfc8-4595-a904-42f9be10f4f6", "metadata": {}, "source": [ "Combine with a range mark to show a percentile interval:" ] }, { "cell_type": "code", "execution_count": null, "id": "05c561c6-0449-4a61-96d1-390611a1b694", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, \"price\", \"cut\")\n", " .add(so.Dots(pointsize=1, alpha=.2), so.Jitter(.3))\n", " .add(so.Range(color=\"k\"), so.Perc([25, 75]), so.Shift(y=.2))\n", " .scale(x=\"log\")\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "d464157c-3187-49c1-9cd8-71f284ce4c50", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.add.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "tips = load_dataset(\"tips\")" ] }, { "cell_type": "raw", "id": "33cd5d3c-d3ad-4e3b-bdac-350f8e104594", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Every layer must be defined with a :class:`Mark`:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"43d0401a-d7d5-4746-a02f-a48f8b5fd1f2", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(tips, \"total_bill\", \"tip\").add(so.Dot())\n", "p" ] }, { "cell_type": "raw", "id": "34b4f581-6126-4d57-ac76-8821c5daa97b", "metadata": {}, "source": [ "Call :class:`Plot.add` multiple times to add multiple layers. In addition to the :class:`Mark`, layers can also be defined with :class:`Stat` or :class:`Move` transforms:" ] }, { "cell_type": "code", "execution_count": null, "id": "693c461e-1dc2-4b44-a9e5-c07b1bf0108b", "metadata": {}, "outputs": [], "source": [ "p.add(so.Line(), so.PolyFit())" ] }, { "cell_type": "raw", "id": "96a61426-0de2-4f4b-a373-0006da6fcceb", "metadata": {}, "source": [ "Multiple transforms can be stacked into a pipeline. " ] }, { "cell_type": "code", "execution_count": null, "id": "b22623a7-bfde-493c-8593-76b145fa1e84", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, y=\"day\", color=\"sex\")\n", " .add(so.Bar(), so.Hist(), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "aa8e6bde-c86c-4bd8-abbe-e0fc64103114", "metadata": {}, "source": [ "Layers have an \"orientation\", which affects the transforms and some marks. The orientation is typically inferred from the variable types assigned to `x` and `y`, but it can be specified when it would otherwise be ambiguous:" ] }, { "cell_type": "code", "execution_count": null, "id": "42be495b-e41b-4883-b061-0973c0e8b496", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"size\", color=\"time\")\n", " .add(so.Dot(alpha=.5), so.Dodge(), so.Jitter(.4), orient=\"y\")\n", ")" ] }, { "cell_type": "raw", "id": "0d2a77f2-6a21-4fe6-a8b1-66978f4f072b", "metadata": {}, "source": [ "Variables can be assigned to a specific layer. 
Note the distinction between how `pointsize` is passed to :class:`Plot.add` — so it is *mapped* by a scale — while `color` and `linewidth` are passed directly to :class:`Line`, so they directly set the line's color and width:" ] }, { "cell_type": "code", "execution_count": null, "id": "e42c3699-c468-4c21-b417-3952311735eb", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, \"total_bill\", \"tip\")\n", " .add(so.Dots(), pointsize=\"size\")\n", " .add(so.Line(color=\".3\", linewidth=3), so.PolyFit())\n", " .scale(pointsize=(2, 10))\n", ")" ] }, { "cell_type": "raw", "id": "d61908e5-9074-443d-9160-2c3101a39bcd", "metadata": {}, "source": [ "Variables that would otherwise apply to the entire plot can also be *excluded* from a specific layer by setting their value to `None`:" ] }, { "cell_type": "code", "execution_count": null, "id": "a095ecca-b428-4bad-a9ab-4d4f05cf61e0", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, \"total_bill\", \"tip\", color=\"day\")\n", " .facet(col=\"day\")\n", " .add(so.Dot(color=\"#aabc\"), col=None, color=None)\n", " .add(so.Dot())\n", ")" ] }, { "cell_type": "raw", "id": "60f94773-668e-441e-9634-41473c26d3bd", "metadata": {}, "source": [ "Variables used only by the transforms *must* be passed at the layer level:" ] }, { "cell_type": "code", "execution_count": null, "id": "0d1ac7e8-5bbd-4a1a-a207-197a4251c2d3", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, \"day\")\n", " .add(so.Bar(), so.Hist(), weight=\"size\")\n", " .label(y=\"Total patrons\")\n", ")" ] }, { "cell_type": "raw", "id": "8a7a5ff7-c0f5-4787-8908-3cb13ea7a047", "metadata": {}, "source": [ "Each layer can be provided with its own data source. 
If a data source was provided in the constructor, the layer data will be joined using its index:" ] }, { "cell_type": "code", "execution_count": null, "id": "45690aaa-1abf-40ae-be3b-1ab648f8be62", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "(\n", " so.Plot(tips, \"total_bill\", \"tip\")\n", " .add(so.Dot(color=\"#aabc\"))\n", " .add(so.Dot(), data=tips.query(\"size == 2\"), color=\"time\")\n", ")" ] }, { "cell_type": "raw", "id": "e62f9e80-bfba-4516-a43a-a265dc35eb79", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Providing a `label` will annotate the layer in the plot's legend:" ] }, { "cell_type": "code", "execution_count": null, "id": "a403012a-e895-4e5b-b690-dc27efbeccad", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"size\")\n", " .add(so.Line(color=\"C1\"), so.Agg(), y=\"total_bill\", label=\"Bill\")\n", " .add(so.Line(color=\"C2\"), so.Agg(), y=\"tip\", label=\"Tip\")\n", " .label(y=\"Value\")\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "c14526a4-37bb-4f4c-84fa-e5c556eee5c2", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.config.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "a38a6fed-51de-4dbc-8d5b-4971d06acf2e", "metadata": { "tags": 
[ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so" ] }, { "cell_type": "raw", "id": "38081259-9382-4623-8d67-09aa114e0949", "metadata": {}, "source": [ "Theme configuration\n", "^^^^^^^^^^^^^^^^^^^\n", "\n", "Theme changes made through the :attr:`Plot.config` interface will apply to all subsequent :class:`Plot` instances. Use the :meth:`Plot.theme` method to modify the theme on a plot-by-plot basis.\n", "\n", "The theme is a dictionary of matplotlib `rc parameters <https://matplotlib.org/stable/tutorials/introductory/customizing.html>`_. You can set individual parameters directly:" ] }, { "cell_type": "code", "execution_count": null, "id": "34ca0ce9-5284-47b6-8281-180709dbec89", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.theme[\"axes.facecolor\"] = \"white\"" ] }, { "cell_type": "raw", "id": "b3f93646-8370-4c16-ace4-7bb811688758", "metadata": {}, "source": [ "To change the overall style of the plot, update the theme with a dictionary of parameters, perhaps from one of seaborn's theming functions:" ] }, { "cell_type": "code", "execution_count": null, "id": "8e5eb7d3-cc7a-4231-b887-db37045f3db4", "metadata": {}, "outputs": [], "source": [ "from seaborn import axes_style\n", "so.Plot.config.theme.update(axes_style(\"whitegrid\"))" ] }, { "cell_type": "raw", "id": "f7c7bd9c-722d-45db-902a-c2dcdef571ee", "metadata": {}, "source": [ "To sync :class:`Plot` with matplotlib's global state, pass the `rcParams` dictionary:" ] }, { "cell_type": "code", "execution_count": null, "id": "fd1cd96e-1a2c-474a-809f-20b8c4794578", "metadata": {}, "outputs": [], "source": [ "import matplotlib as mpl\n", "so.Plot.config.theme.update(mpl.rcParams)" ] }, { "cell_type": "raw", "id": "7e305ec1-4a83-411f-91df-aee2ec4d1806", "metadata": {}, "source": [ "The theme can also be reset back to seaborn defaults:" ] }, { "cell_type": "code", "execution_count": null, "id": "e3146b1d-1b5e-464f-a631-e6d6caf161b3", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.theme.reset()" ] }, { "cell_type": "raw", "id":
"b6370088-02f6-4933-91c0-5763b86b7299", "metadata": {}, "source": [ "Display configuration\n", "^^^^^^^^^^^^^^^^^^^^^\n", "\n", "When returned from the last statement in a notebook cell, a :class:`Plot` will be compiled and embedded in the notebook as an image. By default, the image is rendered as HiDPI PNG. Alternatively, it is possible to display the plots in SVG format:" ] }, { "cell_type": "code", "execution_count": null, "id": "1bd9966e-d08f-46b4-ad44-07276d5efba8", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.display[\"format\"] = \"svg\"" ] }, { "cell_type": "raw", "id": "845239ed-3a0f-4a94-97d0-364c2db3b9c8", "metadata": {}, "source": [ "SVG images use vector graphics with \"infinite\" resolution, so they will appear crisp at any amount of zoom. The downside is that each plot element is drawn separately, so the image data can get very heavy for certain kinds of plots (e.g., for dense scatterplots).\n", "\n", "The HiDPI scaling of the default PNG images will also inflate the size of the notebook they are stored in. (Unlike with SVG, PNG size will scale with the dimensions of the plot but not its complexity). When not useful, it can be disabled:" ] }, { "cell_type": "code", "execution_count": null, "id": "13ac09f7-d4ad-4b4e-8963-edc0c6c71a94", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.display[\"hidpi\"] = False" ] }, { "cell_type": "raw", "id": "ddebe3eb-1d64-41e9-9cfd-f8359d6f8a38", "metadata": {}, "source": [ "The embedded images are scaled down slightly — independently from the figure size or DPI — so that more information can be presented on the screen. 
The precise scaling factor is also configurable:" ] }, { "cell_type": "code", "execution_count": null, "id": "f10c5596-598d-4258-bf8f-67c07eaba266", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.display[\"scaling\"] = 0.7" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.facet.ipynb ================================================ { "cells": [ { "cell_type": "raw", "id": "fb8e120d-5dcf-483b-a0d1-74857d09ce7d", "metadata": {}, "source": [ ".. currentmodule:: seaborn.objects" ] }, { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "markdown", "id": "ae85e302-354c-46ca-a17f-aaec7ed1cbd6", "metadata": {}, "source": [ "Assigning a faceting variable will create multiple subplots and plot subsets of the data on each of them:" ] }, { "cell_type": "code", "execution_count": null, "id": "d65405fd-cf28-4248-8e51-1aa1999354a2", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(penguins, \"bill_length_mm\", \"bill_depth_mm\").add(so.Dots())\n", "p.facet(\"species\")" ] }, { "cell_type": "markdown", "id": "2b9630aa-3b46-4e72-82ef-5717c2d8c686", "metadata": {}, "source": [ "Multiple faceting variables can be defined to create a two-dimensional grid:" ] }, { "cell_type": "code", "execution_count": null, "id": "1857144f-1373-4704-9332-d3fc649ceb9d", "metadata": {}, "outputs": [], 
"source": [ "p.facet(\"species\", \"sex\")" ] }, { "cell_type": "markdown", "id": "7664e2d2-c254-44b4-9973-88e1d013fb3d", "metadata": {}, "source": [ "Facet variables can be provided as references to the global plot data or as vectors:" ] }, { "cell_type": "code", "execution_count": null, "id": "6569616d-480b-4b8c-a761-f5bd2bde60e3", "metadata": {}, "outputs": [], "source": [ "p.facet(penguins[\"island\"])" ] }, { "cell_type": "markdown", "id": "198f63a0-bb0f-40c4-b790-bd15f8656acb", "metadata": {}, "source": [ "With a single faceting variable, arrange the facets or limit to a subset by passing a list of levels to `order`:" ] }, { "cell_type": "code", "execution_count": null, "id": "b1344f7f-50d0-4592-b4fb-ab81d97a4798", "metadata": {}, "outputs": [], "source": [ "p.facet(\"species\", order=[\"Gentoo\", \"Adelie\"])" ] }, { "cell_type": "markdown", "id": "2090297c-414f-4448-a930-5b6f0de18deb", "metadata": {}, "source": [ "With multiple variables, pass `order` as a dictionary:" ] }, { "cell_type": "code", "execution_count": null, "id": "58ed1b13-71a7-462a-af99-78be566268a6", "metadata": {}, "outputs": [], "source": [ "p.facet(\"species\", \"sex\", order={\"col\": [\"Gentoo\", \"Adelie\"], \"row\": [\"Female\", \"Male\"]})" ] }, { "cell_type": "markdown", "id": "e440f14d-24b2-4f83-a247-0bb917f9f4c3", "metadata": {}, "source": [ "When the faceting variable has multiple levels, you can `wrap` it to distribute subplots across both dimensions:" ] }, { "cell_type": "code", "execution_count": null, "id": "92baf66c-6dd9-4f50-adf2-386c4daab094", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(diamonds, x=\"carat\", y=\"price\").add(so.Dots())\n", "p.facet(\"color\", wrap=4)" ] }, { "cell_type": "markdown", "id": "8d0872cb-e261-4796-b81e-a416fea85201", "metadata": {}, "source": [ "Wrapping works only when there is a single variable, but you can wrap in either direction:" ] }, { "cell_type": "code", "execution_count": null, "id": "c5a66a64-bfba-437c-80be-1311e85cf5a5", 
"metadata": {}, "outputs": [], "source": [ "p.facet(row=\"color\", wrap=2)" ] }, { "cell_type": "raw", "id": "e1bdaad7-5883-45ad-af39-c10183569bdc", "metadata": {}, "source": [ "Use :meth:`Plot.share` to specify whether facets should be scaled the same way:" ] }, { "cell_type": "code", "execution_count": null, "id": "14c1f977-79d4-4f9c-a846-1fd70ad3569e", "metadata": {}, "outputs": [], "source": [ "p.facet(\"clarity\", wrap=3).share(x=False)" ] }, { "cell_type": "raw", "id": "a4fc64d9-b7ba-4061-8160-63d8fd89e47a", "metadata": {}, "source": [ "Use :meth:`Plot.label` to tweak the titles:" ] }, { "cell_type": "code", "execution_count": null, "id": "4206b12c-d7a3-419f-b278-6edfe487c5de", "metadata": {}, "outputs": [], "source": [ "p.facet(\"color\").label(title=\"{} grade\".format)" ] }, { "cell_type": "code", "execution_count": null, "id": "28b4fb9d-2bb0-40ff-a541-5f300aca6200", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.label.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "fb32137a-e882-4222-9463-b8cf0ee1c8bd", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Use strings to override 
default labels:" ] }, { "cell_type": "code", "execution_count": null, "id": "65b4320e-6fb9-48ed-9132-53b0d21b85e6", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p = (\n", " so.Plot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", " .add(so.Dot(), color=\"species\")\n", ")\n", "p.label(x=\"Length\", y=\"Depth\", color=\"\")" ] }, { "cell_type": "raw", "id": "a39626d2-76f5-40a9-a3fd-6f44dd69bd30", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Pass a function to *modify* the default label:" ] }, { "cell_type": "code", "execution_count": null, "id": "c3540c54-1c91-4d55-8f58-cd758abbe2fd", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p.label(color=str.capitalize)" ] }, { "cell_type": "markdown", "id": "68f3b321-0755-4ef1-a9e6-bcff61a9178d", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "Use this method to set the title for a single-axes plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "12d23c6e-781f-4b5c-a6b0-3ea0317ab7fb", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p.label(title=\"Penguin species exhibit distinct bill shapes\")" ] }, { "cell_type": "markdown", "id": "8e0bcb80-0929-4ab9-b5c0-13bb3d8e4484", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "When faceting, the `title` parameter will modify default titles:" ] }, { "cell_type": "code", "execution_count": null, "id": "da1516b7-b823-41c0-b251-01bdecb6a4e6", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p.facet(\"sex\").label(title=str.upper)" ] }, { "cell_type": "markdown", "id": "bb439eae-6cc3-4a6c-bef2-b4b7746edbd1", "metadata": { "editable": true, "slideshow": { 
"slide_type": "" }, "tags": [] }, "source": [ "And the `col`/`row` parameters will add labels to the title for each facet:" ] }, { "cell_type": "code", "execution_count": null, "id": "e0d49ba9-0507-4358-b477-2e0253f0df8f", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p.facet(\"sex\").label(col=\"Sex:\")" ] }, { "cell_type": "markdown", "id": "99471c06-1b1a-4ef5-844c-5f4aa8f322f5", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "If more customization is needed, a format string can work well:" ] }, { "cell_type": "code", "execution_count": null, "id": "848be3a3-5a2c-4b98-918f-825257be85ae", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p.facet(\"sex\").label(title=\"{} penguins\".format)" ] }, { "cell_type": "code", "execution_count": null, "id": "94012def-dd7c-48f4-8830-f77a3bf7299b", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "p" ] }, { "cell_type": "raw", "id": "e9b669e9-fd3d-4292-9c8d-e5fb093932b2", "metadata": { "editable": true, "raw_mimetype": "", "slideshow": { "slide_type": "" }, "tags": [] }, "source": [ "When adding labels for each layer, the `legend=` parameter sets the title for the legend:" ] }, { "cell_type": "code", "execution_count": null, "id": "78d22763-3f92-4be1-bc3f-bc24ad39da70", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\")\n", " .add(so.Line(color=\"C1\"), so.Agg(), y=\"bill_length_mm\", label=\"length\")\n", " .add(so.Line(color=\"C2\"), so.Agg(), y=\"bill_depth_mm\", label=\"depth\")\n", " .label(legend=\"Measurement\")\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "5c7a7b91-bb5c-4bf5-99f8-719a220e3b36", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [] }, 
"outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.layout.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so" ] }, { "cell_type": "markdown", "id": "406f8f8d-b590-46f4-a230-626e32e52c71", "metadata": {}, "source": [ "Control the overall dimensions of the figure with `size`:" ] }, { "cell_type": "code", "execution_count": null, "id": "fefc2b45-3510-4cd7-9de9-4806d71fc4c1", "metadata": {}, "outputs": [], "source": [ "p = so.Plot().layout(size=(4, 4))\n", "p" ] }, { "cell_type": "raw", "id": "909a47bb-82f5-455a-99c3-7049d548561b", "metadata": {}, "source": [ "Subplots created by using :meth:`Plot.facet` or :meth:`Plot.pair` will shrink to fit in the available space:" ] }, { "cell_type": "code", "execution_count": null, "id": "3163687c-8d48-4e88-8dc2-35e16341e30e", "metadata": {}, "outputs": [], "source": [ "p.facet([\"A\", \"B\"], [\"X\", \"Y\"])" ] }, { "cell_type": "markdown", "id": "feda7c3a-3862-48d4-bb18-419cd03fc081", "metadata": {}, "source": [ "You may find that different automatic layout engines give better or worse results with specific plots:" ] }, { "cell_type": "code", "execution_count": null, "id": "c2107939-c6a9-414c-b3a2-6f5d0dd60daf", "metadata": {}, "outputs": [], "source": [ "p.facet([\"A\", \"B\"], [\"X\", \"Y\"]).layout(engine=\"constrained\")" ] }, { "cell_type": "markdown", "id": "d61054d1-dcef-4e11-9802-394bcc633f9f", 
"metadata": {}, "source": [ "With `extent`, you can control the size of the plot relative to the underlying figure. Because the notebook display adapts the figure background to the plot, this appears only to change the plot size in a notebook context. But it can be useful when saving or displaying through a `pyplot` GUI window:" ] }, { "cell_type": "code", "execution_count": null, "id": "1b5d5969-2925-474f-8e3c-99e4f90a7a2b", "metadata": {}, "outputs": [], "source": [ "p.layout(extent=[0, 0, .8, 1]).show()" ] }, { "cell_type": "code", "execution_count": null, "id": "e5c41b7d-a064-4406-8571-a544b194f3dc", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.limit.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so" ] }, { "cell_type": "raw", "id": "1888667e-8761-4c32-9510-68e08e64f21d", "metadata": {}, "source": [ "By default, plot limits are automatically set to provide a small margin around the data (controlled by :meth:`Plot.theme` parameters `axes.xmargin` and `axes.ymargin`):" ] }, { "cell_type": "code", "execution_count": null, "id": "25ec46d9-3c60-4962-b182-a2b2c8310305", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(x=[1, 2, 3], y=[1, 3, 2]).add(so.Line(marker=\"o\"))\n", "p" ] }, { "cell_type": "raw", "id": "5f5c19d8-4104-4df0-ae45-9a8ac96d024e", "metadata": {}, "source": [ "Pass a `min`/`max` tuple to pin the 
limits at specific values:" ] }, { "cell_type": "code", "execution_count": null, "id": "804388c5-5efa-4cfb-92d8-97fdf838ae5e", "metadata": {}, "outputs": [], "source": [ "p.limit(x=(0, 4), y=(-1, 6))" ] }, { "cell_type": "markdown", "id": "49634203-4c77-42ae-abc1-b182671f305e", "metadata": {}, "source": [ "Reversing the `min`/`max` values will invert the axis:" ] }, { "cell_type": "code", "execution_count": null, "id": "6ea1c82c-a9bc-43cc-ba75-5ee28923b8f2", "metadata": {}, "outputs": [], "source": [ "p.limit(y=(4, 0))" ] }, { "cell_type": "raw", "id": "9bb25c70-3960-4a81-891c-2bd299e7b24f", "metadata": {}, "source": [ "Use `None` for either side to maintain the default value:" ] }, { "cell_type": "code", "execution_count": null, "id": "d0566ba8-707c-4808-9a76-525ccaef7a42", "metadata": {}, "outputs": [], "source": [ "p.limit(y=(0, None))" ] }, { "cell_type": "code", "execution_count": null, "id": "fefc2b45-3510-4cd7-9de9-4806d71fc4c1", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.on.ipynb ================================================ { "cells": [ { "cell_type": "raw", "id": "fb8e120d-5dcf-483b-a0d1-74857d09ce7d", "metadata": {}, "source": [ ".. 
currentmodule:: seaborn.objects" ] }, { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "%config InlineBackend.figure_format = \"retina\"\n", "import seaborn as sns\n", "import seaborn.objects as so\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "from seaborn import load_dataset\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "3445ed22-7a6a-4f91-8914-49bb1af023cb", "metadata": {}, "source": [ "Passing a :class:`matplotlib.axes.Axes` object provides functionality closest to seaborn's axes-level plotting functions. Notice how the resulting image looks different from others created with :class:`Plot`. This is because the plot theme uses the global rcParams at the time the axes were created, rather than :class:`Plot` defaults:" ] }, { "cell_type": "code", "execution_count": null, "id": "b816b0b1-b861-404e-bec6-9b2b0844ea5a", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(diamonds, \"carat\", \"price\").add(so.Dots())\n", "f, ax = plt.subplots()\n", "p.on(ax).show()" ] }, { "cell_type": "raw", "id": "ce3aa102-50fe-44ce-9e06-e25d14b410f1", "metadata": {}, "source": [ "Alternatively, calling :func:`matplotlib.pyplot.figure` will defer axes creation to :class:`Plot`, which will apply the default theme (and any customizations specified with :meth:`Plot.theme`):" ] }, { "cell_type": "code", "execution_count": null, "id": "52eefae9-d08e-48fb-a15b-27920609d53b", "metadata": {}, "outputs": [], "source": [ "f = plt.figure()\n", "p.on(f).show()" ] }, { "cell_type": "raw", "id": "171fa466-1f7a-4c5e-8a12-61edb3f11e4a", "metadata": {}, "source": [ "Creating a :class:`matplotlib.figure.Figure` object will bypass `pyplot` altogether. 
This may be useful for embedding :class:`Plot` figures in a GUI application:" ] }, { "cell_type": "code", "execution_count": null, "id": "bba83103-ab74-4e3c-b16e-77644f4c0431", "metadata": {}, "outputs": [], "source": [ "f = mpl.figure.Figure()\n", "p.on(f).plot()" ] }, { "cell_type": "raw", "id": "4cce3d40-acea-4f5c-87c4-56666480d2fe", "metadata": {}, "source": [ "Using :class:`Plot.on` also provides access to the underlying matplotlib objects, which may be useful for deep customization. But it requires a careful attention to the order of operations by which the :class:`Plot` is specified, compiled, customized, and displayed:" ] }, { "cell_type": "code", "execution_count": null, "id": "91823d24-8269-4b72-abeb-38201eb2db3f", "metadata": {}, "outputs": [], "source": [ "f = mpl.figure.Figure()\n", "res = p.on(f).plot()\n", "\n", "ax = f.axes[0]\n", "rect = mpl.patches.Rectangle(\n", " xy=(0, 1), width=.4, height=.1,\n", " color=\"C1\", alpha=.2,\n", " transform=ax.transAxes, clip_on=False,\n", ")\n", "ax.add_artist(rect)\n", "ax.text(\n", " x=rect.get_width() / 2, y=1 + rect.get_height() / 2,\n", " s=\"Diamonds: very sparkly!\", size=12,\n", " ha=\"center\", va=\"center\", transform=ax.transAxes,\n", ")\n", "\n", "res" ] }, { "cell_type": "raw", "id": "61286891-25b3-4db5-8ebe-af080d5c5f31", "metadata": {}, "source": [ "Matplotlib 3.4 introduced the concept of :meth:`matplotlib.figure.Figure.subfigures`, which make it easier to composite multiple arrangements of subplots. 
These can also be passed to :meth:`Plot.on`:
currentmodule:: seaborn.objects" ] }, { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "mpg = load_dataset(\"mpg\")" ] }, { "cell_type": "markdown", "id": "a6ee48da-ff1e-41eb-95ec-9f2dd12bdb63", "metadata": {}, "source": [ "Plot one dependent variable against multiple independent variables by assigning `y` and pairing on `x`:" ] }, { "cell_type": "code", "execution_count": null, "id": "56ab58b6-ccdf-4938-a8e0-cbe2de8d6749", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg, y=\"acceleration\")\n", " .pair(x=[\"displacement\", \"weight\"])\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "markdown", "id": "c37e0543-d022-4079-b58a-8f8af90b29c8", "metadata": {}, "source": [ "Show multiple pairwise relationships by passing lists to both `x` and `y`:" ] }, { "cell_type": "code", "execution_count": null, "id": "39b5298d-d578-4284-8fab-415d2c03022d", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg)\n", " .pair(x=[\"displacement\", \"weight\"], y=[\"horsepower\", \"acceleration\"])\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "markdown", "id": "09bf54ad-bf55-4e26-8566-5af62bf29c51", "metadata": {}, "source": [ "When providing lists for both `x` and `y`, pass `cross=False` to pair each position in the list rather than showing all pairwise relationships:" ] }, { "cell_type": "code", "execution_count": null, "id": "c70ca7d8-79ee-4c7a-ae91-2088e965b1f4", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg)\n", " .pair(\n", " x=[\"weight\", \"acceleration\"],\n", " y=[\"displacement\", \"horsepower\"],\n", " cross=False,\n", " )\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "markdown", "id": "79beadec-038d-40f0-8783-749474d48eac", "metadata": {}, "source": [ "When plotting against several `x` or `y` variables, it is possible to `wrap` the subplots to 
produce a two-dimensional grid:" ] }, { "cell_type": "code", "execution_count": null, "id": "2bf2d87f-a940-426c-bdff-8bf80696b7a1", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg, y=\"mpg\")\n", " .pair(x=[\"displacement\", \"weight\", \"horsepower\", \"cylinders\"], wrap=2)\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "markdown", "id": "6304faed-2466-49eb-a8c2-d9d635938b78", "metadata": {}, "source": [ "Pairing can be combined with faceting, either pairing on `y` and faceting on `col` or pairing on `x` and faceting on `row`:" ] }, { "cell_type": "code", "execution_count": null, "id": "bea235cd-e9c1-4119-a683-871e60b149ec", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg, x=\"weight\")\n", " .pair(y=[\"horsepower\", \"acceleration\"])\n", " .facet(col=\"origin\")\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "markdown", "id": "ded931d2-95f1-4e09-8e24-f8b687f8f052", "metadata": {}, "source": [ "While typically convenient to assign pairing variables as references to the common `data`, it's also possible to pass a list of vectors:" ] }, { "cell_type": "code", "execution_count": null, "id": "66e0cb77-094b-4144-b086-15bab106ca9f", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg[\"weight\"])\n", " .pair(y=[mpg[\"horsepower\"], mpg[\"acceleration\"]])\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "raw", "id": "7bef3310-87f6-44f6-be6a-e30effaa7a70", "metadata": {}, "source": [ "When customizing the plot through methods like :meth:`Plot.label`, :meth:`Plot.limit`, or :meth:`Plot.scale`, you can refer to the individual coordinate variables as `x0`, `x1`, etc.:" ] }, { "cell_type": "code", "execution_count": null, "id": "d6ce8868-55c0-4c44-8fed-937771b762ee", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(mpg, y=\"mpg\")\n", " .pair(x=[\"weight\", \"displacement\"])\n", " .label(x0=\"Weight (lb)\", x1=\"Displacement (cu in)\", y=\"MPG\")\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "code", 
"execution_count": null, "id": "358d409f-8b7c-4901-8eec-b2cf51731483", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.scale.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "diamonds = load_dataset(\"diamonds\")\n", "mpg = load_dataset(\"mpg\").query(\"cylinders in [4, 6, 8]\")" ] }, { "cell_type": "raw", "id": "bd43bcc6-b060-49c2-a429-8ea0ab046e2c", "metadata": {}, "source": [ "Passing the name of a function, such as `\"log\"` or `\"symlog\"` will set the scale's transform:" ] }, { "cell_type": "code", "execution_count": null, "id": "84b84cc1-ef1c-461e-b4af-4ce6e99886d1", "metadata": {}, "outputs": [], "source": [ "p1 = so.Plot(diamonds, x=\"carat\", y=\"price\")\n", "p1.add(so.Dots()).scale(y=\"log\")" ] }, { "cell_type": "raw", "id": "b5ea9f7f-c776-48af-a4be-0053c3c12036", "metadata": {}, "source": [ "String arguments can also specify the the name of a palette that defines the output values (or \"range\") of the scale:" ] }, { "cell_type": "code", "execution_count": null, "id": "e1f64d2f-6abd-48aa-9bab-c3e4614d0302", "metadata": {}, "outputs": [], "source": [ "p1.add(so.Dots(), color=\"clarity\").scale(color=\"crest\")" ] }, { "cell_type": "raw", "id": "37df8672-33b1-49a8-b702-a87c8b95db99", "metadata": {}, "source": [ "The scale's range can alternatively be specified as a 
tuple of min/max values:" ] }, { "cell_type": "code", "execution_count": null, "id": "371b8abd-ddfb-42f9-b730-f75b0e7b5fd6", "metadata": {}, "outputs": [], "source": [ "p1.add(so.Dots(), pointsize=\"carat\").scale(pointsize=(2, 10))" ] }, { "cell_type": "raw", "id": "f0c4ead3-e950-48e4-9c81-c8734a8458d0", "metadata": {}, "source": [ "The tuple format can also be used for a color scale:" ] }, { "cell_type": "code", "execution_count": null, "id": "678fd8b2-b031-4ec6-a567-a6711f722cbd", "metadata": {}, "outputs": [], "source": [ "p1.add(so.Dots(), color=\"carat\").scale(color=(\".4\", \"#68d\"))" ] }, { "cell_type": "raw", "id": "b6445ab7-2ec1-40be-95bc-9df0a5750bf5", "metadata": {}, "source": [ "For more control pass a scale object, such as :class:`Continuous`, which allows you to specify the input domain (`norm`), output range (`values`), and nonlinear transform (`trans`):" ] }, { "cell_type": "code", "execution_count": null, "id": "d6a219ef-b50e-442e-82e9-8ae9e2cdb825", "metadata": { "tags": [] }, "outputs": [], "source": [ "(\n", " p1.add(so.Dots(), color=\"carat\")\n", " .scale(color=so.Continuous((\".4\", \"#68d\"), norm=(1, 3), trans=\"sqrt\"))\n", ")" ] }, { "cell_type": "markdown", "id": "737e73a9-a0d5-4311-8c5c-4ca42f9194bf", "metadata": { "tags": [] }, "source": [ "The scale objects also offer an interface for configuring the location of the scale ticks (including in the legend) and the formatting of the tick labels:" ] }, { "cell_type": "code", "execution_count": null, "id": "cfaa426a-1a97-4b6f-91b6-ee378eabf194", "metadata": {}, "outputs": [], "source": [ "(\n", " p1.add(so.Dots(), color=\"price\")\n", " .scale(\n", " x=so.Continuous(trans=\"sqrt\").tick(every=.5),\n", " y=so.Continuous().label(like=\"${x:g}\"),\n", " color=so.Continuous(\"ch:.2\").tick(upto=4).label(unit=\"\"),\n", " )\n", " .label(y=\"\")\n", ")" ] }, { "cell_type": "raw", "id": "d4013795-fd5d-4a53-b145-e87f876a0684", "metadata": {}, "source": [ "If the scale includes a nonlinear 
transform, it will be applied *before* any statistical transforms:" ] }, { "cell_type": "code", "execution_count": null, "id": "e9bf321f-c482-4d25-bb3b-7c499930b0d1", "metadata": {}, "outputs": [], "source": [ "(\n", " p1.add(so.Dots(color=\".7\"))\n", " .add(so.Line(), so.PolyFit(order=2))\n", " .scale(y=\"log\")\n", " .limit(y=(250, 25000))\n", ")" ] }, { "cell_type": "raw", "id": "00ac5844-efb1-4683-a8ff-e864d0c68dff", "metadata": {}, "source": [ "The scale is also relevant for when numerical data should be treated as categories. Consider the following histogram:" ] }, { "cell_type": "code", "execution_count": null, "id": "04d5e6ae-30b2-495b-be1a-d99d6ffd4f44", "metadata": {}, "outputs": [], "source": [ "p2 = so.Plot(mpg, \"cylinders\").add(so.Bar(), so.Hist())\n", "p2" ] }, { "cell_type": "raw", "id": "9b3dafad-aae0-4862-b1b2-bb76b75a9cec", "metadata": {}, "source": [ "By default, the plot gives `cylinders` a continuous scale, since it is a vector of floats. But assigning a :class:`Nominal` scale causes the histogram to bin observations properly:" ] }, { "cell_type": "code", "execution_count": null, "id": "0f89331a-69fc-4714-adfb-0568690c1b66", "metadata": {}, "outputs": [], "source": [ "p2.scale(x=so.Nominal())" ] }, { "cell_type": "raw", "id": "78880057-f4a7-40a1-a619-20d4b3be34dc", "metadata": {}, "source": [ "The default behavior for semantic mappings also depends on input data types and can be modified by the scale. 
Consider the sequential mapping applied to the colors in this plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "653abbc6-8227-48eb-9e1d-31587e6ef46d", "metadata": {}, "outputs": [], "source": [ "p3 = (\n", " so.Plot(mpg, \"weight\", \"acceleration\", color=\"cylinders\")\n", " .add(so.Dot(), marker=\"origin\")\n", ")\n", "p3" ] }, { "cell_type": "raw", "id": "6ce5c9a8-5051-43b1-973c-fb9fb35ba399", "metadata": {}, "source": [ "Passing the name of a qualitative palette will select a :class:`Nominal` scale:" ] }, { "cell_type": "code", "execution_count": null, "id": "218d6619-1fe3-4412-a2fc-efed4f542db7", "metadata": {}, "outputs": [], "source": [ "p3.scale(color=\"deep\")" ] }, { "cell_type": "raw", "id": "d2362247-6e0e-48fb-bbe4-2149f96785ae", "metadata": {}, "source": [ "A :class:`Nominal` scale is also implied when the output values are given as a list or dictionary:" ] }, { "cell_type": "code", "execution_count": null, "id": "8bdf57da-cb05-4347-87ec-fac2c3763f12", "metadata": {}, "outputs": [], "source": [ "p3.scale(\n", " color=[\"#49b\", \"#a6a\", \"#5b8\"],\n", " marker={\"japan\": \".\", \"europe\": \"+\", \"usa\": \"*\"},\n", ")" ] }, { "cell_type": "raw", "id": "a7d92be7-9e96-4850-a26a-090c5ae9857b", "metadata": {}, "source": [ "Pass a :class:`Nominal` object directly to control the order of the category mappings:" ] }, { "cell_type": "code", "execution_count": null, "id": "a3c7eeb9-351f-484d-b0af-e18341569de3", "metadata": {}, "outputs": [], "source": [ "p3.scale(\n", " color=so.Nominal([\"#008fd5\", \"#fc4f30\", \"#e5ae38\"]),\n", " marker=so.Nominal(order=[\"japan\", \"europe\", \"usa\"])\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "d8885056-fd98-4964-a4a1-8c0344960409", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", 
"mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.share.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "3a874676-6b0d-45b1-a227-857a536c5ed2", "metadata": {}, "source": [ "By default, faceted plots will share all axes:" ] }, { "cell_type": "code", "execution_count": null, "id": "615d0765-98c7-4694-8115-a6d1b3557fe7", "metadata": {}, "outputs": [], "source": [ "p = (\n", " so.Plot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", " .facet(col=\"species\", row=\"sex\")\n", " .add(so.Dots())\n", ")\n", "p" ] }, { "cell_type": "raw", "id": "8b75feb1-491e-4031-9fcb-619037bd1bfb", "metadata": {}, "source": [ "Set a coordinate variable to `False` to let each subplot adapt independently:" ] }, { "cell_type": "code", "execution_count": null, "id": "4c23c570-ca9b-49cc-9aab-7d167218454b", "metadata": {}, "outputs": [], "source": [ "p.share(x=False, y=False)" ] }, { "cell_type": "markdown", "id": "cc46d8d0-7ab9-44c2-8a28-c656fe86c085", "metadata": {}, "source": [ "It's also possible to share only across rows or columns:" ] }, { "cell_type": "code", "execution_count": null, "id": "7cb8136b-9aa3-4c48-bd41-fc0e19fa997c", "metadata": {}, "outputs": [], "source": [ "p.share(x=\"col\", y=\"row\")" ] }, { "cell_type": "raw", "id": "91533aba-45ae-4011-b72c-10f5f79e01d0", "metadata": {}, "source": [ "This method is also relevant for paired plots, which have different defaults. 
In this case, you would need to opt *in* to full sharing (although it may not always make sense):" ] }, { "cell_type": "code", "execution_count": null, "id": "e2b71770-e520-45b9-b41c-a66431f21e1f", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, y=\"flipper_length_mm\")\n", " .pair(x=[\"bill_length_mm\", \"bill_depth_mm\"])\n", " .add(so.Dots())\n", " .share(x=True)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "92c29080-8561-4c90-8581-4d435a5f96b9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Plot.theme.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "9252d5a5-8af1-4f99-b799-ee044329fb23", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "anscombe = load_dataset(\"anscombe\")" ] }, { "cell_type": "raw", "id": "406f6608-daf2-4d3e-9f2c-1a9e93ecb840", "metadata": {}, "source": [ "The default theme uses the same parameters as :func:`seaborn.set_theme` with no additional arguments:" ] }, { "cell_type": "code", "execution_count": null, "id": "5e3d639c-1167-48d2-b9b5-c26b7fa12c66", "metadata": {}, "outputs": [], "source": [ "p = (\n", " so.Plot(anscombe, \"x\", \"y\", color=\"dataset\")\n", " .facet(\"dataset\", wrap=2)\n", " .add(so.Line(), so.PolyFit(order=1))\n", " .add(so.Dot())\n", ")\n", "p" ] }, { "cell_type": "raw", "id": "e2823a91-47f1-40a8-a150-32f00bcb59ea", "metadata": {}, "source": [ "Pass a dictionary of rc 
parameters to change the appearance of the plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "368c8cdb-2e6f-4520-8412-cd1864a6c09b", "metadata": {}, "outputs": [], "source": [ "p.theme({\"axes.facecolor\": \"w\", \"axes.edgecolor\": \"slategray\"})" ] }, { "cell_type": "raw", "id": "637cf0ba-e9b7-4f0f-a628-854e300c4122", "metadata": {}, "source": [ "Many (though not all) mark properties will reflect theme parameters by default:" ] }, { "cell_type": "code", "execution_count": null, "id": "9eb330b3-f424-405b-9653-5df9948792d9", "metadata": {}, "outputs": [], "source": [ "p.theme({\"lines.linewidth\": 4})" ] }, { "cell_type": "raw", "id": "0186e852-9c47-4da1-999a-f61f41687dfb", "metadata": {}, "source": [ "Apply seaborn styles by passing in the output of the style functions:" ] }, { "cell_type": "code", "execution_count": null, "id": "48cafbb1-37da-42c7-a20e-b63c0fef4d41", "metadata": {}, "outputs": [], "source": [ "from seaborn import axes_style\n", "p.theme(axes_style(\"ticks\"))" ] }, { "cell_type": "raw", "id": "bbdecb4b-382a-49f3-8928-16f5f72c39b5", "metadata": {}, "source": [ "Or apply styles that ship with matplotlib:" ] }, { "cell_type": "code", "execution_count": null, "id": "84a7ac28-798d-4560-bbc8-d214fd6fcada", "metadata": {}, "outputs": [], "source": [ "from matplotlib import style\n", "p.theme(style.library[\"fivethirtyeight\"])" ] }, { "cell_type": "raw", "id": "e1870ad0-48a0-4fd1-a557-d337979bc845", "metadata": {}, "source": [ "Multiple parameter dictionaries should be passed to the same function call. 
On Python 3.9+, you can use dictionary union syntax for this:" ] }, { "cell_type": "code", "execution_count": null, "id": "dec4db5b-1b2b-4b9d-97e1-9cf0f20d6b83", "metadata": {}, "outputs": [], "source": [ "from seaborn import plotting_context\n", "p.theme(axes_style(\"whitegrid\") | plotting_context(\"talk\"))" ] }, { "cell_type": "raw", "id": "7cc09720-887d-463e-a162-1e3ef8a46ad9", "metadata": {}, "source": [ "The default theme for all :class:`Plot` instances can be changed using the :attr:`Plot.config` attribute:" ] }, { "cell_type": "code", "execution_count": null, "id": "4e535ddf-d394-4ce1-8d09-4dc95ca314b4", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.theme.update(axes_style(\"white\"))\n", "p" ] }, { "cell_type": "raw", "id": "2f19f645-3f8d-4044-82e9-4a87165a0078", "metadata": {}, "source": [ "See :ref:`Plot Configuration <plot_config>` for more details." ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Range.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2923956c-f141-4ecb-ab08-e819099f0fa9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")" ] }, { "cell_type": "raw", "id": "576cbc86-f869-47b5-a98f-6ee727287a8b", "metadata": {}, "source": [ "This mark will often be used in the context of a stat transform that adds an errorbar interval:" ] }, { "cell_type": "code", "execution_count": null, "id": "f6217b85-7479-49fd-aeda-9f435aa0473a", "metadata": {}, "outputs": 
[], "source": [ "(\n", " so.Plot(penguins, x=\"body_mass_g\", y=\"species\", color=\"sex\")\n", " .add(so.Dot(), so.Agg(), so.Dodge())\n", " .add(so.Range(), so.Est(errorbar=\"sd\"), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "e156ea24-d8b4-4d67-acb5-750034be4dde", "metadata": {}, "source": [ "One feature (or potential gotcha) is that the mark will pick up properties like `linestyle` and `linewidth`; exclude those properties from the relevant layer if this behavior is undesired:" ] }, { "cell_type": "code", "execution_count": null, "id": "4bb63ebb-7733-4313-844c-cb7613298da3", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"sex\", y=\"body_mass_g\", linestyle=\"species\")\n", " .facet(\"species\")\n", " .add(so.Line(marker=\"o\"), so.Agg())\n", " .add(so.Range(), so.Est(errorbar=\"sd\"))\n", ")" ] }, { "cell_type": "raw", "id": "5387e049-b343-49ea-a943-7dd9c090f184", "metadata": {}, "source": [ "It's also possible to directly assign the minimum and maximum values for the range:" ] }, { "cell_type": "code", "execution_count": null, "id": "4e795770-4481-4e23-a49b-e828a1f5cbbd", "metadata": {}, "outputs": [], "source": [ "(\n", " penguins\n", " .rename_axis(index=\"penguin\")\n", " .pipe(so.Plot, x=\"penguin\", ymin=\"bill_depth_mm\", ymax=\"bill_length_mm\")\n", " .add(so.Range(), color=\"island\")\n", ")" ] }, { "cell_type": "markdown", "id": "2191bec6-a02e-48e0-b92c-69c38826049d", "metadata": {}, "source": [ "When `min`/`max` variables are neither computed as part of a transform nor explicitly assigned, the range will cover the full extent of the data at each unique observation on the orient axis:" ] }, { "cell_type": "code", "execution_count": null, "id": "63c6352e-4ef5-4cff-940e-35fa5804b2c7", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"sex\", y=\"body_mass_g\")\n", " .facet(\"species\")\n", " .add(so.Dots(pointsize=6))\n", " .add(so.Range(linewidth=2))\n", ")" ] }, { "cell_type": "code", 
"execution_count": null, "id": "c215deb1-e510-4631-b999-737f5f41cae2", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Shift.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "2605c8d0-5872-4dff-9172-db81fac1cee1", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "penguins = load_dataset(\"penguins\")\n", "diamonds = load_dataset(\"diamonds\")" ] }, { "cell_type": "raw", "id": "e70d701a-cd7c-4b38-aaa0-4729e2be56d9", "metadata": {}, "source": [ "Use this transform to layer multiple marks that would otherwise overlap and be hard to interpret:" ] }, { "cell_type": "code", "execution_count": null, "id": "5ea7a2c4-cb69-4ad0-8ea8-73067b756371", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, \"species\", \"body_mass_g\")\n", " .add(so.Dots(), so.Jitter())\n", " .add(so.Range(), so.Perc([25, 75]), so.Shift(x=.2))\n", ")" ] }, { "cell_type": "raw", "id": "940b87b2-04fb-40ba-a62f-52f461039ab9", "metadata": {}, "source": [ "For y variables with a nominal scale, bear in mind that the axis will be inverted and a positive shift will move downwards:" ] }, { "cell_type": "code", "execution_count": null, "id": "54b5f728-4fbc-474a-8865-0f58d0ad9b0b", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, \"carat\", \"clarity\")\n", " .add(so.Dots(), so.Jitter())\n", " .add(so.Range(), so.Perc([25, 75]), so.Shift(y=.25))\n", ")" ] }, { "cell_type": 
"code", "execution_count": null, "id": "78d9bb6a-ea3d-491e-b43e-25efd386bd59", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Stack.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "87244f49-8cf2-4668-a556-a8c7828b31bf", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "titanic = load_dataset(\"titanic\").sort_values(\"alive\", ascending=False)" ] }, { "cell_type": "raw", "id": "c9a1a7db-f365-4c5f-85ae-1f00e15b0af9", "metadata": {}, "source": [ "This transform applies a vertical shift to eliminate overlap between marks with a baseline, such as :class:`Bar` or :class:`Area`:" ] }, { "cell_type": "code", "execution_count": null, "id": "07579f71-842d-4dc1-98ab-38652409238d", "metadata": {}, "outputs": [], "source": [ "so.Plot(titanic, x=\"class\", color=\"sex\").add(so.Bar(), so.Count(), so.Stack())" ] }, { "cell_type": "raw", "id": "2488a821-3bf1-4bb9-9963-bf726d11925c", "metadata": {}, "source": [ "Stacking can make it much harder to compare values between groups that get shifted, but it can work well when depicting a part-whole relationship:" ] }, { "cell_type": "code", "execution_count": null, "id": "dcb8ea58-3cf2-455b-b6b7-98b434f2f152", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(titanic, x=\"age\", alpha=\"alive\")\n", " .facet(\"sex\")\n", " .add(so.Bars(), so.Hist(binwidth=10), so.Stack())\n", ")" ] }, { "cell_type": "code", "execution_count": 
null, "id": "b649198f-898e-4103-84bc-d74de71de5a7", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/objects.Text.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "cd1cdefe-b8c1-40b9-be31-006d52ec9f18", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn.objects as so\n", "from seaborn import load_dataset\n", "glue = (\n", " load_dataset(\"glue\")\n", " .pivot(index=[\"Model\", \"Encoder\"], columns=\"Task\", values=\"Score\")\n", " .assign(Average=lambda x: x.mean(axis=1).round(1))\n", " .sort_values(\"Average\", ascending=False)\n", ")" ] }, { "cell_type": "raw", "id": "3e49ffb1-8778-4cd5-80d6-9d7e1438bc9c", "metadata": {}, "source": [ "Add text at x/y locations on the plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "3bf21068-d39e-436c-8deb-aa1b15aeb2b3", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(glue, x=\"SST-2\", y=\"MRPC\", text=\"Model\")\n", " .add(so.Text())\n", ")" ] }, { "cell_type": "raw", "id": "a4b9a8b2-6603-46db-9ede-3b3fb45e0e64", "metadata": {}, "source": [ "Add bar annotations, horizontally-aligned with `halign`:" ] }, { "cell_type": "code", "execution_count": null, "id": "f68501f0-c868-439e-9485-d71cca86ea47", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(glue, x=\"Average\", y=\"Model\", text=\"Average\")\n", " .add(so.Bar())\n", " .add(so.Text(color=\"w\", halign=\"right\"))\n", ")" ] }, { "cell_type": "raw", "id": "a9d39479-0afa-477b-8403-fe92a54643c9", "metadata": 
{}, "source": [ "Fine-tune the alignment using `offset`:" ] }, { "cell_type": "code", "execution_count": null, "id": "b5da4a9d-79f3-4c11-bab3-f89da8512ce4", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(glue, x=\"Average\", y=\"Model\", text=\"Average\")\n", " .add(so.Bar())\n", " .add(so.Text(color=\"w\", halign=\"right\", offset=6))\n", ")" ] }, { "cell_type": "raw", "id": "e9c43798-70d5-42b5-bd91-b85684d1b671", "metadata": {}, "source": [ "Add text above dots, mapping the text color with a third variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "b2d26ebc-24ac-4531-9ba2-fa03720c58bc", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(glue, x=\"SST-2\", y=\"MRPC\", color=\"Encoder\", text=\"Model\")\n", " .add(so.Dot())\n", " .add(so.Text(valign=\"bottom\"))\n", "\n", ")" ] }, { "cell_type": "raw", "id": "f31aaa38-6728-4299-8422-8762c52c9857", "metadata": {}, "source": [ "Map the text alignment for better use of space:" ] }, { "cell_type": "code", "execution_count": null, "id": "cf4bbf0c-0c5f-4c31-b971-720ea8910918", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(glue, x=\"RTE\", y=\"MRPC\", color=\"Encoder\", text=\"Model\")\n", " .add(so.Dot())\n", " .add(so.Text(), halign=\"Encoder\")\n", " .scale(halign={\"LSTM\": \"left\", \"Transformer\": \"right\"})\n", ")" ] }, { "cell_type": "raw", "id": "a5de35a6-1ccf-4958-8013-edd9ed1cd4b0", "metadata": {}, "source": [ "Use additional matplotlib parameters to control the appearance of the text:" ] }, { "cell_type": "code", "execution_count": null, "id": "9c4be188-1614-4c19-9bd7-b07e986f6a23", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(glue, x=\"RTE\", y=\"MRPC\", color=\"Encoder\", text=\"Model\")\n", " .add(so.Dot())\n", " .add(so.Text({\"fontweight\": \"bold\"}), halign=\"Encoder\")\n", " .scale(halign={\"LSTM\": \"left\", \"Transformer\": \"right\"})\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": 
"95fb7aee-090a-4415-917c-b5258d2b298b", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/pairplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"ticks\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The simplest invocation uses :func:`scatterplot` for each pairing of the variables and :func:`histplot` for the marginal plots along the diagonal:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.pairplot(penguins)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a ``hue`` variable adds a semantic mapping and changes the default marginal plot to a layered kernel density estimate (KDE):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's possible to force marginal histograms:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, hue=\"species\", diag_kind=\"hist\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``kind`` parameter determines both the diagonal and off-diagonal plotting style. 
Several options are available, including using :func:`kdeplot` to draw KDEs:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or :func:`histplot` to draw both bivariate and univariate histograms:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, kind=\"hist\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``markers`` parameter applies a style mapping on the off-diagonal axes. Currently, it will be redundant with the ``hue`` variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, hue=\"species\", markers=[\"o\", \"s\", \"D\"])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As with other figure-level functions, the size of the figure is controlled by setting the ``height`` of each individual subplot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, height=1.5)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use ``vars`` or ``x_vars`` and ``y_vars`` to select the variables to plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(\n", " penguins,\n", " x_vars=[\"bill_length_mm\", \"bill_depth_mm\", \"flipper_length_mm\"],\n", " y_vars=[\"bill_length_mm\", \"bill_depth_mm\"],\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Set ``corner=True`` to plot only the lower triangle:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins, corner=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``plot_kws`` and ``diag_kws`` parameters accept dicts of keyword arguments to customize the off-diagonal and diagonal plots, respectively:" ] 
}, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(\n", " penguins,\n", " plot_kws=dict(marker=\"+\", linewidth=1),\n", " diag_kws=dict(fill=False),\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The return object is the underlying :class:`PairGrid`, which can be used to further customize the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.pairplot(penguins, diag_kind=\"kde\")\n", "g.map_lower(sns.kdeplot, levels=4, color=\".2\")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/plotting_context.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "perceived-worry", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns" ] }, { "cell_type": "markdown", "id": "seventh-volleyball", "metadata": {}, "source": [ "Calling with no arguments will return the current defaults for the parameters that get scaled:" ] }, { "cell_type": "code", "execution_count": null, "id": "roman-villa", "metadata": { "tags": [ "show-output" ] }, "outputs": [], "source": [ "sns.plotting_context()" ] }, { "cell_type": "markdown", "id": "handled-texas", "metadata": {}, "source": [ "Calling with the name of a predefined context will show those values:" ] }, { "cell_type": "code", "execution_count": null, "id": "distant-caribbean", "metadata": { "tags": [ "show-output" ] }, "outputs": [], "source": [ "sns.plotting_context(\"talk\")" ] }, { "cell_type": "markdown", "id": 
"lightweight-anime", "metadata": {}, "source": [ "Use the function as a context manager to temporarily change the parameter values:" ] }, { "cell_type": "code", "execution_count": null, "id": "contemporary-hampshire", "metadata": {}, "outputs": [], "source": [ "with sns.plotting_context(\"talk\"):\n", " sns.lineplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "code", "execution_count": null, "id": "accompanied-brisbane", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/pointplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "43f842ee-44c9-476b-ab08-112d23e2effb", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")\n", "penguins = sns.load_dataset(\"penguins\")\n", "flights = sns.load_dataset(\"flights\")" ] }, { "cell_type": "raw", "id": "f25d3647-9fad-47b2-b49d-db6f5b5c3795", "metadata": {}, "source": [ "Group by a categorical variable and plot aggregated values, with confidence intervals:" ] }, { "cell_type": "code", "execution_count": null, "id": "9a865fec-c034-4000-938d-b7cd89157495", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(data=penguins, x=\"island\", y=\"body_mass_g\")" ] }, { "cell_type": "raw", "id": "c65257ad-c87f-4b78-9b6c-cf792a691598", "metadata": {}, "source": [ "Add a second layer of grouping and differentiate with color:" ] }, { "cell_type": "code", "execution_count": null, "id": "f27011f1-0e3c-4dc4-818e-4a77930977b9", "metadata": {}, 
"outputs": [], "source": [ "sns.pointplot(data=penguins, x=\"island\", y=\"body_mass_g\", hue=\"sex\")" ] }, { "cell_type": "raw", "id": "d51a887c-1f64-4ddf-af31-0476a983818b", "metadata": {}, "source": [ "Redundantly code the `hue` variable using the markers and linestyles for better accessibility:" ] }, { "cell_type": "code", "execution_count": null, "id": "1bfb8bc1-6f9a-49a1-8b1d-6bcc992cb249", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(\n", " data=penguins,\n", " x=\"island\", y=\"body_mass_g\", hue=\"sex\",\n", " markers=[\"o\", \"s\"], linestyles=[\"-\", \"--\"],\n", ")" ] }, { "cell_type": "raw", "id": "44a11b7a-6847-4225-906e-58bbb56c6966", "metadata": {}, "source": [ "Use the error bars to represent the standard deviation of each distribution:" ] }, { "cell_type": "code", "execution_count": null, "id": "386b25eb-7ab7-4a1d-9498-cef3e4fd3e6b", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(data=penguins, x=\"island\", y=\"body_mass_g\", errorbar=\"sd\")" ] }, { "cell_type": "raw", "id": "7490d4b8-d2ca-4cad-9ba3-5862aafb8165", "metadata": {}, "source": [ "Customize the appearance of the plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "50b14810-2299-479c-b6c5-0fd10c4ed3de", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(\n", " data=penguins, x=\"body_mass_g\", y=\"island\",\n", " errorbar=(\"pi\", 100), capsize=.4,\n", " color=\".5\", linestyle=\"none\", marker=\"D\",\n", ")" ] }, { "cell_type": "raw", "id": "479e4e0c-42c9-4d79-88eb-e397840a7e78", "metadata": {}, "source": [ "\"Dodge\" the artists along the categorical axis to reduce overplotting:" ] }, { "cell_type": "code", "execution_count": null, "id": "8f94d069-c5f4-4579-a4bf-6d755962d48d", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(data=penguins, x=\"sex\", y=\"bill_depth_mm\", hue=\"species\", dodge=True)" ] }, { "cell_type": "raw", "id": "00273ada-cd12-410a-a268-38243d6514ae", "metadata": {}, "source": [ "Dodge by a specific amount, 
relative to the width allotted for each level:" ] }, { "cell_type": "code", "execution_count": null, "id": "94d6718d-2cfe-44f4-88e5-f47461d7d51f", "metadata": {}, "outputs": [], "source": [ "sns.stripplot(\n", " data=penguins, x=\"species\", y=\"bill_depth_mm\", hue=\"sex\",\n", " dodge=True, alpha=.2, legend=False,\n", ")\n", "sns.pointplot(\n", " data=penguins, x=\"species\", y=\"bill_depth_mm\", hue=\"sex\",\n", " dodge=.4, linestyle=\"none\", errorbar=None,\n", " marker=\"_\", markersize=20, markeredgewidth=3,\n", ")" ] }, { "cell_type": "raw", "id": "e205e7c8-6b11-44e6-b43f-7416c427215d", "metadata": {}, "source": [ "When variables are not explicitly assigned and the dataset is two-dimensional, the plot will aggregate over each column:" ] }, { "cell_type": "code", "execution_count": null, "id": "e721e3b7-25c8-4e9c-a748-1c36b06d1100", "metadata": {}, "outputs": [], "source": [ "flights_wide = flights.pivot(index=\"year\", columns=\"month\", values=\"passengers\")\n", "sns.pointplot(flights_wide)" ] }, { "cell_type": "raw", "id": "0d2d7811-06e3-4882-86e3-225071c864f7", "metadata": {}, "source": [ "With one-dimensional data, each value is plotted (relative to its key or index when available):" ] }, { "cell_type": "code", "execution_count": null, "id": "acd91ddc-2a27-4b05-80fa-00ddcf1ae63e", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(flights_wide[\"Jun\"])" ] }, { "cell_type": "raw", "id": "573c2ba7-1e46-494d-9076-19b1c04b58c1", "metadata": {}, "source": [ "Control the formatting of the categorical variable as it appears in the tick labels:" ] }, { "cell_type": "code", "execution_count": null, "id": "7af6ab85-bfb1-42c2-8c68-2c91f22968d6", "metadata": {}, "outputs": [], "source": [ "sns.pointplot(flights_wide[\"Jun\"], formatter=lambda x: f\"'{x % 1900}\")" ] }, { "cell_type": "raw", "id": "c319e82e-1387-4c2b-8daf-3b7174cad180", "metadata": {}, "source": [ "Or preserve the native scale of the grouping variable:" ] }, { "cell_type": "code", 
"execution_count": null, "id": "e92e8af9-b734-4e4d-a240-7f3982fcfbcc", "metadata": {}, "outputs": [], "source": [ "ax = sns.pointplot(flights_wide[\"Jun\"], native_scale=True)\n", "ax.plot(1955, 335, marker=\"*\", color=\"r\", markersize=10)" ] }, { "cell_type": "code", "execution_count": null, "id": "0f88be5c-7919-48cf-a84f-a5e6ac86e888", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/regplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "611aed40-d120-4fbf-b1e6-9712ed8167fc", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "sns.set_theme()\n", "mpg = sns.load_dataset(\"mpg\")" ] }, { "cell_type": "raw", "id": "61bebade-0c45-4e99-9567-dfe0bc2dc6e1", "metadata": {}, "source": [ "Plot the relationship between two variables in a DataFrame:" ] }, { "cell_type": "code", "execution_count": null, "id": "2f4107db-d89b-46ad-a4c6-9ba1181b2122", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"weight\", y=\"acceleration\")" ] }, { "cell_type": "raw", "id": "146225d0-2e38-4b92-8e64-6d7f78311f40", "metadata": {}, "source": [ "Fit a higher-order polynomial regression to capture nonlinear trends:" ] }, { "cell_type": "code", "execution_count": null, "id": "ba29488c-8a45-4387-bfb1-71a584fa1b3d", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"weight\", y=\"mpg\", order=2)" ] }, { "cell_type": "raw", "id": "0ad71f54-b362-465e-8780-1d8b99ff2d51", "metadata": {}, 
"source": [ "Alternatively, fit a log-linear regression:" ] }, { "cell_type": "code", "execution_count": null, "id": "aae2acaa-ed07-4568-97d2-8665603eb7eb", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"displacement\", y=\"mpg\", logx=True)" ] }, { "cell_type": "raw", "id": "eef37c8a-7190-465c-b963-076ec17e1b3a", "metadata": {}, "source": [ "Or use a locally-weighted (LOWESS) smoother:" ] }, { "cell_type": "code", "execution_count": null, "id": "9276c469-72ea-4c36-9b7c-19ecba564376", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"horsepower\", y=\"mpg\", lowess=True)" ] }, { "cell_type": "raw", "id": "d18f1534-598e-4f08-91dd-0c4020f30b00", "metadata": {}, "source": [ "Fit a logistic regression when the response variable is binary:" ] }, { "cell_type": "code", "execution_count": null, "id": "79ec9180-10c9-4910-9713-dcd1fdd266be", "metadata": {}, "outputs": [], "source": [ "sns.regplot(x=mpg[\"weight\"], y=mpg[\"origin\"].eq(\"usa\").rename(\"from_usa\"), logistic=True)" ] }, { "cell_type": "raw", "id": "2e165783-d505-4acb-a20a-d22a49965c2b", "metadata": {}, "source": [ "Fit a robust regression to downweight the influence of outliers:" ] }, { "cell_type": "code", "execution_count": null, "id": "fd5cf940-de8f-4230-8b04-5c650418f3c4", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"horsepower\", y=\"weight\", robust=True)" ] }, { "cell_type": "raw", "id": "e7d43c4e-e819-4634-8269-cbf5de4a2f24", "metadata": {}, "source": [ "Disable the confidence interval for faster plotting:" ] }, { "cell_type": "code", "execution_count": null, "id": "b21384ff-6395-4fa9-b7da-63e8a951d8a5", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"weight\", y=\"horsepower\", ci=None)" ] }, { "cell_type": "raw", "id": "06e979ac-f418-4ead-bde1-ec684d0545ff", "metadata": {}, "source": [ "Jitter the scatterplot when the `x` variable is discrete:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"543a8ace-a89e-4af9-bf6d-a8722ebdfac5", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"cylinders\", y=\"weight\", x_jitter=.15)" ] }, { "cell_type": "raw", "id": "c3042eb2-0933-4886-9bff-88c276371516", "metadata": {}, "source": [ "Or aggregate over the distinct `x` values:" ] }, { "cell_type": "code", "execution_count": null, "id": "158c6e36-8858-415b-b78c-7d8d79879ee5", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"cylinders\", y=\"acceleration\", x_estimator=np.mean, order=2)" ] }, { "cell_type": "raw", "id": "d9cefe7a-7f86-4353-95da-d7e72e65d4fc", "metadata": {}, "source": [ "With a continuous `x` variable, bin and then aggregate:" ] }, { "cell_type": "code", "execution_count": null, "id": "1c48829b-2e3b-4e6b-9b1d-5ba69f713617", "metadata": {}, "outputs": [], "source": [ "sns.regplot(data=mpg, x=\"weight\", y=\"mpg\", x_bins=np.arange(2000, 5500, 250), order=2)" ] }, { "cell_type": "raw", "id": "dfe5a36a-20b0-4e69-b986-fede8e1506cc", "metadata": {}, "source": [ "Customize the appearance of various elements:" ] }, { "cell_type": "code", "execution_count": null, "id": "df689a39-c5e1-4f7b-a8f9-8ffb09b95238", "metadata": {}, "outputs": [], "source": [ "sns.regplot(\n", " data=mpg, x=\"weight\", y=\"horsepower\",\n", " ci=99, marker=\"x\", color=\".3\", line_kws=dict(color=\"r\"),\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "d625745b-3706-447b-9224-88e6cb1eb7f9", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/relplot.ipynb 
================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ "These examples will illustrate only some of the functionality that :func:`relplot` is capable of. For more information, consult the examples for :func:`scatterplot` and :func:`lineplot`, which are used when ``kind=\"scatter\"`` or ``kind=\"line\"``, respectively." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "sns.set_theme(style=\"ticks\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To illustrate ``kind=\"scatter\"`` (the default style of plot), we will use the \"tips\" dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "tips.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning ``x`` and ``y`` and any semantic mapping variables will draw a single plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a ``col`` variable creates a faceted figure with multiple subplots arranged across the columns of the grid:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\", col=\"time\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Different variables can be assigned to facet on both the columns and rows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\", col=\"time\", row=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When the variable assigned to ``col`` has many levels, it can be 
\"wrapped\" across multiple rows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\", col=\"day\", col_wrap=2)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning multiple semantic variables can show multi-dimensional relationships, but be mindful to avoid making an overly-complicated plot." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips, x=\"total_bill\", y=\"tip\", col=\"time\",\n", " hue=\"time\", size=\"size\", style=\"sex\",\n", " palette=[\"b\", \"r\"], sizes=(10, 100)\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When there is a natural continuity to one of the variables, it makes more sense to show lines instead of points. To draw the figure using :func:`lineplot`, set ``kind=\"line\"``. We will illustrate this effect with the \"fmri\" dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fmri = sns.load_dataset(\"fmri\")\n", "fmri.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Using ``kind=\"line\"`` offers the same flexibility for semantic mappings as ``kind=\"scatter\"``, but :func:`lineplot` transforms the data more before plotting. Observations are sorted by their ``x`` value, and repeated observations are aggregated. 
By default, the resulting plot shows the mean and 95% CI for each unit:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, x=\"timepoint\", y=\"signal\", col=\"region\",\n", " hue=\"event\", style=\"event\", kind=\"line\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The size and shape of the figure are parametrized by the ``height`` and ``aspect`` ratio of each individual facet:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri,\n", " x=\"timepoint\", y=\"signal\",\n", " hue=\"event\", style=\"event\", col=\"region\",\n", " height=4, aspect=.7, kind=\"line\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The object returned by :func:`relplot` is always a :class:`FacetGrid`, which has several methods that allow you to quickly tweak the title, labels, and other aspects of the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.relplot(\n", " data=fmri,\n", " x=\"timepoint\", y=\"signal\",\n", " hue=\"event\", style=\"event\", col=\"region\",\n", " height=4, aspect=.7, kind=\"line\"\n", ")\n", "(g.map(plt.axhline, y=0, color=\".7\", dashes=(2, 1), zorder=0)\n", " .set_axis_labels(\"Timepoint\", \"Percent signal change\")\n", " .set_titles(\"Region: {col_name} cortex\")\n", " .tight_layout(w_pad=0))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It is also possible to use wide-form data with :func:`relplot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_wide = (\n", " sns.load_dataset(\"flights\")\n", " .pivot(index=\"year\", columns=\"month\", values=\"passengers\")\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Faceting is not an option in this case, but the plot will still take advantage of the external legend offered by 
:class:`FacetGrid`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=flights_wide, kind=\"line\")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/residplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "776f8271-21ed-4707-a1ad-09d8c63ae95a", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme()\n", "mpg = sns.load_dataset(\"mpg\")" ] }, { "cell_type": "raw", "id": "85717971-adc9-45b0-9c4b-3f022d96179c", "metadata": {}, "source": [ "Pass `x` and `y` to see a scatter plot of the residuals after fitting a simple regression model:" ] }, { "cell_type": "code", "execution_count": null, "id": "5aea4655-fb51-4b51-b41d-4769de50e956", "metadata": {}, "outputs": [], "source": [ "sns.residplot(data=mpg, x=\"weight\", y=\"displacement\")" ] }, { "cell_type": "raw", "id": "175b6287-9240-493f-94bc-9d18258e952b", "metadata": {}, "source": [ "Structure in the residual plot can reveal a violation of linear regression assumptions:" ] }, { "cell_type": "code", "execution_count": null, "id": "39aa84c2-d623-44be-9b0b-746f52b55fd4", "metadata": {}, "outputs": [], "source": [ "sns.residplot(data=mpg, x=\"horsepower\", y=\"mpg\")" ] }, { "cell_type": "raw", "id": "bd9641e4-8df5-4751-b261-6443888fbbfe", "metadata": {}, "source": [ "Remove higher-order trends to test whether that stabilizes the residuals:" ] }, { "cell_type": "code", "execution_count": null, "id": "03a68199-1272-464b-8b85-7a309c22a4a6", 
"metadata": {}, "outputs": [], "source": [ "sns.residplot(data=mpg, x=\"horsepower\", y=\"mpg\", order=2)" ] }, { "cell_type": "raw", "id": "b17750af-0393-4c53-8057-bf95d0de821a", "metadata": {}, "source": [ "Adding a LOWESS curve can help reveal or emphasize structure:" ] }, { "cell_type": "code", "execution_count": null, "id": "494359bd-47b2-426e-9c35-14b5351eec93", "metadata": {}, "outputs": [], "source": [ "sns.residplot(data=mpg, x=\"horsepower\", y=\"mpg\", lowess=True, line_kws=dict(color=\"r\"))" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/rugplot.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "Add a rug along one of the axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import seaborn as sns; sns.set_theme()\n", "tips = sns.load_dataset(\"tips\")\n", "sns.kdeplot(data=tips, x=\"total_bill\")\n", "sns.rugplot(data=tips, x=\"total_bill\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Add a rug along both axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")\n", "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Represent a third variable with hue mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\")\n", "sns.rugplot(data=tips, 
x=\"total_bill\", y=\"tip\", hue=\"time\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Draw a taller rug:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")\n", "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\", height=.1)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Put the rug outside the axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")\n", "sns.rugplot(data=tips, x=\"total_bill\", y=\"tip\", height=-.02, clip_on=False)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Show the density of a larger dataset using thinner lines and alpha blending:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "diamonds = sns.load_dataset(\"diamonds\")\n", "sns.scatterplot(data=diamonds, x=\"carat\", y=\"price\", s=5)\n", "sns.rugplot(data=diamonds, x=\"carat\", y=\"price\", lw=1, alpha=.005)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/scatterplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "sns.set_theme()" ] }, { "cell_type": "raw", 
"metadata": {}, "source": [ "These examples will use the \"tips\" dataset, which has a mixture of numeric and categorical variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "tips.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Passing long-form data and assigning ``x`` and ``y`` will draw a scatter plot between two variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a variable to ``hue`` will map its levels to the color of the points:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning the same variable to ``style`` will also vary the markers and create a more accessible plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"time\", style=\"time\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning ``hue`` and ``style`` to different variables will vary colors and markers independently:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"day\", style=\"time\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If the variable assigned to ``hue`` is numeric, the semantic mapping will be quantitative and use a different default palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Pass the 
name of a categorical palette or explicit colors (as a Python list or dictionary) to force categorical mapping of the ``hue`` variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", palette=\"deep\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If there are a large number of unique numeric values, the legend will show a representative, evenly-spaced set:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tip_rate = tips.eval(\"tip / total_bill\").rename(\"tip_rate\")\n", "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=tip_rate)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A numeric variable can also be assigned to ``size`` to apply a semantic mapping to the areas of the points:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", size=\"size\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Control the range of marker areas with ``sizes``, and set ``legend=\"full\"`` to force every unique value to appear in the legend:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(\n", " data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", size=\"size\",\n", " sizes=(20, 200), legend=\"full\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Pass a tuple of values or a :class:`matplotlib.colors.Normalize` object to ``hue_norm`` to control the quantitative hue mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(\n", " data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\", size=\"size\",\n", " sizes=(20, 200), hue_norm=(0, 7), legend=\"full\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ 
"Control the specific markers used to map the ``style`` variable by passing a Python list or dictionary of marker codes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "markers = {\"Lunch\": \"s\", \"Dinner\": \"X\"}\n", "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", style=\"time\", markers=markers)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Additional keyword arguments are passed to :meth:`matplotlib.axes.Axes.scatter`, allowing you to directly set the attributes of the plot that are not semantically mapped:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.scatterplot(data=tips, x=\"total_bill\", y=\"tip\", s=100, color=\".2\", marker=\"+\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The previous examples used a long-form dataset. When working with wide-form data, each column will be plotted against its index using both ``hue`` and ``style`` mapping:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "index = pd.date_range(\"1 1 2000\", periods=100, freq=\"ME\", name=\"date\")\n", "data = np.random.randn(100, 4).cumsum(axis=0)\n", "wide_df = pd.DataFrame(data, index, [\"a\", \"b\", \"c\", \"d\"])\n", "sns.scatterplot(data=wide_df)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use :func:`relplot` to combine :func:`scatterplot` and :class:`FacetGrid`. This allows grouping within additional categorical variables, and plotting them across multiple subplots.\n", "\n", "Using :func:`relplot` is safer than using :class:`FacetGrid` directly, as it ensures synchronization of the semantic mappings across facets." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips, x=\"total_bill\", y=\"tip\",\n", " col=\"time\", hue=\"day\", style=\"day\",\n", " kind=\"scatter\"\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/set_context.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "thorough-equipment", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns" ] }, { "cell_type": "markdown", "id": "canadian-protection", "metadata": {}, "source": [ "Call the function with the name of a context to set the default for all plots:" ] }, { "cell_type": "code", "execution_count": null, "id": "freelance-leonard", "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"notebook\")\n", "sns.lineplot(x=[0, 1, 2], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "studied-adventure", "metadata": {}, "source": [ "You can independently scale the font elements relative to the current context:" ] }, { "cell_type": "code", "execution_count": null, "id": "irish-digest", "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"notebook\", font_scale=1.25)\n", "sns.lineplot(x=[0, 1, 2], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "fourth-technical", "metadata": {}, "source": [ "It is also possible to override some of the parameters with specific values:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"advance-request", "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"notebook\", rc={\"lines.linewidth\": 3})\n", "sns.lineplot(x=[0, 1, 2], y=[1, 3, 2])" ] }, { "cell_type": "code", "execution_count": null, "id": "compatible-string", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/set_style.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "practical-announcement", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns" ] }, { "cell_type": "markdown", "id": "suffering-emerald", "metadata": {}, "source": [ "Call the function with the name of a seaborn style to set the default for all plots:" ] }, { "cell_type": "code", "execution_count": null, "id": "collaborative-struggle", "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"whitegrid\")\n", "sns.barplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "defensive-surgery", "metadata": {}, "source": [ "You can also selectively override seaborn's default parameter values:" ] }, { "cell_type": "code", "execution_count": null, "id": "coastal-sydney", "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"darkgrid\", {\"grid.color\": \".6\", \"grid.linestyle\": \":\"})\n", "sns.lineplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "code", "execution_count": null, "id": "bright-october", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": 
"py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/set_theme.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "flush-block", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "markdown", "id": "remarkable-confirmation", "metadata": {}, "source": [ "By default, seaborn plots will be made with the current values of the matplotlib rcParams:" ] }, { "cell_type": "code", "execution_count": null, "id": "viral-highway", "metadata": {}, "outputs": [], "source": [ "sns.barplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "hungarian-poster", "metadata": {}, "source": [ "Calling this function with no arguments will activate seaborn's \"default\" theme:" ] }, { "cell_type": "code", "execution_count": null, "id": "front-february", "metadata": {}, "outputs": [], "source": [ "sns.set_theme()\n", "sns.barplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "daily-mills", "metadata": {}, "source": [ "Note that this will take effect for *all* matplotlib plots, including those not made using seaborn:" ] }, { "cell_type": "code", "execution_count": null, "id": "essential-replica", "metadata": {}, "outputs": [], "source": [ "plt.bar([\"A\", \"B\", \"C\"], [1, 3, 2])" ] }, { "cell_type": "markdown", "id": "naughty-edgar", "metadata": {}, "source": [ "The seaborn theme is decomposed into several distinct sets of parameters that you can control independently:" ] }, { "cell_type": "code", "execution_count": null, "id": "latin-conversion", "metadata": {}, "outputs": [], 
"source": [ "sns.set_theme(style=\"whitegrid\", palette=\"pastel\")\n", "sns.barplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "durable-cycling", "metadata": {}, "source": [ "Pass `None` to preserve the current values for a given set of parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "blessed-chuck", "metadata": {}, "outputs": [], "source": [ "sns.set_theme(style=\"white\", palette=None)\n", "sns.barplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "markdown", "id": "present-writing", "metadata": {}, "source": [ "You can also override any seaborn parameters or define additional parameters that are part of the matplotlib rc system but not included in the seaborn themes:" ] }, { "cell_type": "code", "execution_count": null, "id": "floppy-effectiveness", "metadata": {}, "outputs": [], "source": [ "custom_params = {\"axes.spines.right\": False, \"axes.spines.top\": False}\n", "sns.set_theme(style=\"ticks\", rc=custom_params)\n", "sns.barplot(x=[\"A\", \"B\", \"C\"], y=[1, 3, 2])" ] }, { "cell_type": "code", "execution_count": null, "id": "large-transfer", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_docstrings/stripplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a single numeric variable shows its univariate 
distribution with points randomly \"jittered\" on the other axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.stripplot(data=tips, x=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a second variable splits the strips of points to compare categorical levels of that variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Show vertically-oriented strips by swapping the assignment of the categorical and numerical variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"day\", y=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Prior to version 0.12, the levels of the categorical variable had different colors by default. 
To get the same effect, assign the `hue` variable explicitly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"day\", legend=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or you can assign a distinct variable to `hue` to show a multidimensional relationship:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If the `hue` variable is numeric, it will be mapped with a quantitative palette by default (note that this was not the case prior to version 0.12):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use `palette` to control the color mapping, including forcing a categorical mapping by passing the name of a qualitative palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\", palette=\"deep\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the different levels of the `hue` variable are intermingled in each strip, but setting `dodge=True` will split them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The random jitter can be disabled by setting `jitter=False`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True, jitter=False)" ] }, { "cell_type": "markdown", 
"metadata": {}, "source": [ "If plotting in wide-form mode, each numeric column of the dataframe will be mapped to both `x` and `hue`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To change the orientation while in wide-form mode, pass `orient` explicitly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, orient=\"h\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The `orient` parameter is also useful when both axis variables are numeric, as it will resolve ambiguity about which dimension to group (and jitter) along:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the categorical variable will be mapped to discrete indices with a fixed scale (0, 1, ...), even when it is numeric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(\n", " data=tips.query(\"size in [2, 3, 5]\"),\n", " x=\"total_bill\", y=\"size\", orient=\"h\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To disable this behavior and use the original scale of the variable, set `native_scale=True`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(\n", " data=tips.query(\"size in [2, 3, 5]\"),\n", " x=\"total_bill\", y=\"size\", orient=\"h\",\n", " native_scale=True,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Further visual customization can be achieved by passing keyword arguments for :func:`matplotlib.axes.Axes.scatter`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.stripplot(\n", " 
data=tips, x=\"total_bill\", y=\"day\", hue=\"time\",\n", " jitter=False, s=20, marker=\"D\", linewidth=1, alpha=.1,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To make a plot with multiple facets, it is safer to use :func:`catplot` than to work with :class:`FacetGrid` directly, because :func:`catplot` will ensure that the categorical and hue variables are properly synchronized in each facet:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"time\", y=\"total_bill\", hue=\"sex\", col=\"day\", aspect=.5)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/swarmplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a single numeric variable shows its univariate distribution with points adjusted along the other axis such that they don't overlap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.swarmplot(data=tips, x=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a second variable splits the groups of points to compare categorical levels of that variable:" ] }, { "cell_type": "code", "execution_count": null, 
"metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Show vertically-oriented swarms by swapping the assignment of the categorical and numerical variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"day\", y=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Prior to version 0.12, the levels of the categorical variable had different colors by default. To get the same effect, assign the `hue` variable explicitly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"day\", legend=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or you can assign a distinct variable to `hue` to show a multidimensional relationship:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If the `hue` variable is numeric, it will be mapped with a quantitative palette by default (note that this was not the case prior to version 0.12):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Use `palette` to control the color mapping, including forcing a categorical mapping by passing the name of a qualitative palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"size\", palette=\"deep\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the different levels of the `hue` variable are intermingled in each swarm, 
but setting `dodge=True` will split them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", dodge=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The \"orientation\" of the plot (defined as the direction along which quantitative relationships are preserved) is usually inferred automatically. But in ambiguous cases, such as when both axis variables are numeric, it can be specified:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When the local density of points is too high, they will be forced to overlap in the \"gutters\" of each swarm and a warning will be issued. Decreasing the size of the points can help to avoid this problem:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(data=tips, x=\"total_bill\", y=\"size\", orient=\"h\", size=3)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the categorical variable will be mapped to discrete indices with a fixed scale (0, 1, ...), even when it is numeric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(\n", " data=tips.query(\"size in [2, 3, 5]\"),\n", " x=\"total_bill\", y=\"size\", orient=\"h\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To disable this behavior and use the original scale of the variable, set `native_scale=True` (notice how this also changes the order of the variables on the y axis):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(\n", " data=tips.query(\"size in [2, 3, 5]\"),\n", " x=\"total_bill\", y=\"size\", orient=\"h\",\n", " native_scale=True,\n", ")" ] }, { 
"cell_type": "raw", "metadata": {}, "source": [ "Further visual customization can be achieved by passing keyword arguments for :func:`matplotlib.axes.Axes.scatter`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.swarmplot(\n", " data=tips, x=\"total_bill\", y=\"day\",\n", " marker=\"x\", linewidth=1, \n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To make a plot with multiple facets, it is safer to use :func:`catplot` with `kind=\"swarm\"` than to work with :class:`FacetGrid` directly, because :func:`catplot` will ensure that the categorical and hue variables are properly synchronized in each facet:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=tips, kind=\"swarm\",\n", " x=\"time\", y=\"total_bill\", hue=\"sex\", col=\"day\",\n", " aspect=.5\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_docstrings/violinplot.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": null, "id": "cc19031c-bc2f-4294-95ce-3a2d9b86f44d", "metadata": { "editable": true, "slideshow": { "slide_type": "" }, "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "sns.set_theme(style=\"whitegrid\")" ] }, { "cell_type": "raw", "id": "c72b5394-ff5f-42b1-b083-2e42b2ffdf0f", "metadata": {}, "source": [ "The default violinplot represents a distribution two ways: a patch showing a symmetric kernel 
density estimate (KDE), and the quartiles / whiskers of a box plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "27d578fb-1c20-4d31-b93d-b1b4a053992b", "metadata": {}, "outputs": [], "source": [ "df = sns.load_dataset(\"titanic\")\n", "sns.violinplot(x=df[\"age\"])" ] }, { "cell_type": "raw", "id": "e7d25589-0dc9-48ce-92f9-ab61ffbf964a", "metadata": {}, "source": [ "In a bivariate plot, one of the variables will \"group\" so that multiple violins are drawn:" ] }, { "cell_type": "code", "execution_count": null, "id": "2b851b2c-0011-4cff-8719-11f6138c44e7", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", y=\"class\")" ] }, { "cell_type": "raw", "id": "6d588b32-b14b-4b33-bbd9-69b17f8212a6", "metadata": {}, "source": [ "By default, the orientation of the plot is determined by the variable types, preferring to group by a categorical variable:" ] }, { "cell_type": "code", "execution_count": null, "id": "4810c8e7-0864-496f-8e86-a6527369b9e1", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"class\", y=\"age\", hue=\"alive\")" ] }, { "cell_type": "raw", "id": "402812f2-c024-4179-9fee-fed92f03deb2", "metadata": {}, "source": [ "Pass `fill=False` to draw line-art violins:" ] }, { "cell_type": "code", "execution_count": null, "id": "8e00ce8b-5871-486b-8c55-a4f2e764aa86", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"class\", y=\"age\", hue=\"alive\", fill=False)" ] }, { "cell_type": "raw", "id": "8350abce-6a40-4e18-9501-7d358192471b", "metadata": {}, "source": [ "Draw \"split\" violins to take up less space, and only show the data quartiles:" ] }, { "cell_type": "code", "execution_count": null, "id": "2ae35376-5272-496c-afec-c60a3426f1bf", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"class\", y=\"age\", hue=\"alive\", split=True, inner=\"quart\")" ] }, { "cell_type": "raw", "id": "90f4263f-7294-4ad5-bff4-25d7d796cb45", "metadata": {}, "source": [ "Add a
small gap between the dodged violins:" ] }, { "cell_type": "code", "execution_count": null, "id": "26cb5b89-496d-4893-8914-ca8b6fbf97b7", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"class\", y=\"age\", hue=\"alive\", split=True, gap=.1, inner=\"quart\")" ] }, { "cell_type": "raw", "id": "bbea49e0-7b08-4b25-8686-1d5404b71601", "metadata": {}, "source": [ "Starting in version 0.13.0, it is possible to \"split\" single violins:" ] }, { "cell_type": "code", "execution_count": null, "id": "ba261531-a280-44e5-b8c0-bcc5a53f60bf", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"class\", y=\"age\", split=True, inner=\"quart\")" ] }, { "cell_type": "raw", "id": "7c4dafa1-2747-4b43-ba4a-4c9b32778086", "metadata": {}, "source": [ "Represent every observation inside the distribution by setting `inner=\"stick\"` or `inner=\"point\"`:" ] }, { "cell_type": "code", "execution_count": null, "id": "00b5f00e-a515-4e53-9d73-d13b045cd4c8", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", y=\"deck\", inner=\"point\")" ] }, { "cell_type": "raw", "id": "23c13695-cd01-4da8-bc89-2519ae445f9f", "metadata": {}, "source": [ "Normalize the width of each violin to represent the number of observations:" ] }, { "cell_type": "code", "execution_count": null, "id": "be59f17e-824e-4a8c-a0e1-a27874a05df6", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", y=\"deck\", inner=\"point\", density_norm=\"count\")" ] }, { "cell_type": "raw", "id": "abe650fb-4d26-4bac-97f3-f451a3872cf5", "metadata": {}, "source": [ "By default, the KDE will smooth past the extremes of the observed data; set `cut=0` to prevent this:" ] }, { "cell_type": "code", "execution_count": null, "id": "82556de0-3756-426c-a591-9af6ed6c45d4", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", y=\"alive\", cut=0, inner=\"stick\")" ] }, { "cell_type": "raw", "id": "abfb9e78-d524-4536-90ef-c71834b055f9", 
"metadata": {}, "source": [ "The `bw_adjust` parameter controls the amount of smoothing:" ] }, { "cell_type": "code", "execution_count": null, "id": "8d17e1e3-e0f4-4d2c-ac6e-aec42ed75390", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", y=\"alive\", bw_adjust=.5, inner=\"stick\")" ] }, { "cell_type": "raw", "id": "407bc513-5b7f-418c-8ffe-ec488836586d", "metadata": {}, "source": [ "By default, the violins are drawn at fixed positions on a categorical scale, even if the grouping variable is numeric. Starting in version 0.13.0, pass the `native_scale=True` parameter to preserve the original scale on both axes:" ] }, { "cell_type": "code", "execution_count": null, "id": "e7b6d901-9a97-4716-8d24-1b30145e9c57", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(x=df[\"age\"].round(-1) + 5, y=df[\"fare\"], native_scale=True)" ] }, { "cell_type": "raw", "id": "790e3989-0b47-4e77-9bdb-dc757d1e938c", "metadata": {}, "source": [ "When using a categorical scale, the `formatter` parameter accepts a function that defines categories:" ] }, { "cell_type": "code", "execution_count": null, "id": "28a769d4-3e23-4b53-a9ef-391d5fc24201", "metadata": {}, "outputs": [], "source": [ "decades = lambda x: f\"{int(x)}–{int(x + 10)}\"\n", "sns.violinplot(x=df[\"age\"].round(-1), y=df[\"fare\"], formatter=decades)" ] }, { "cell_type": "raw", "id": "6f914d73-7a0c-4fbc-8432-40c4f0577857", "metadata": {}, "source": [ "By default, the \"inner\" representation scales with the `linewidth` and `linecolor` parameters:" ] }, { "cell_type": "code", "execution_count": null, "id": "18cb2afd-8487-40bd-b3f2-1f83243ffa3c", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", linewidth=1, linecolor=\"k\")" ] }, { "cell_type": "raw", "id": "ca2ef541-c07f-4853-ba98-ce75855ba262", "metadata": {}, "source": [ "Use `inner_kws` to pass parameters directly to the inner plotting functions:" ] }, { "cell_type": "code", "execution_count": null, "id": 
"934f91bc-2698-4c07-92cf-4e6039c801b2", "metadata": {}, "outputs": [], "source": [ "sns.violinplot(data=df, x=\"age\", inner_kws=dict(box_width=15, whis_width=2, color=\".8\"))" ] }, { "cell_type": "code", "execution_count": null, "id": "4aa00d3c-f016-4db8-b6b0-da4e6a327831", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_static/copybutton.js ================================================ // originally taken from scikit-learn's Sphinx theme $(document).ready(function() { /* Add a [>>>] button on the top-right corner of code samples to hide * the >>> and ... prompts and the output and thus make the code * copyable. 
* Note: This JS snippet was taken from the official python.org * documentation site.*/ var div = $('.highlight-python .highlight,' + '.highlight-python3 .highlight,' + '.highlight-pycon .highlight') var pre = div.find('pre'); // get the styles from the current theme pre.parent().parent().css('position', 'relative'); var hide_text = 'Hide the prompts and output'; var show_text = 'Show the prompts and output'; var border_width = pre.css('border-top-width'); var border_style = pre.css('border-top-style'); var border_color = pre.css('border-top-color'); var button_styles = { 'cursor':'pointer', 'position': 'absolute', 'top': '0', 'right': '0', 'border-color': border_color, 'border-style': border_style, 'border-width': border_width, 'color': border_color, 'text-size': '75%', 'font-family': 'monospace', 'padding-left': '0.2em', 'padding-right': '0.2em' } // create and add the button to all the code blocks that contain >>> div.each(function(index) { var jthis = $(this); if (jthis.find('.gp').length > 0) { var button = $('>>>'); button.css(button_styles) button.attr('title', hide_text); jthis.prepend(button); } // tracebacks (.gt) contain bare text elements that need to be // wrapped in a span to work with .nextUntil() (see later) jthis.find('pre:has(.gt)').contents().filter(function() { return ((this.nodeType == 3) && (this.data.trim().length > 0)); }).wrap(''); }); // define the behavior of the button when it's clicked $('.copybutton').toggle( function() { var button = $(this); button.parent().find('.go, .gp, .gt').hide(); button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'hidden'); button.css('text-decoration', 'line-through'); button.attr('title', show_text); }, function() { var button = $(this); button.parent().find('.go, .gp, .gt').show(); button.next('pre').find('.gt').nextUntil('.gp, .go').css('visibility', 'visible'); button.css('text-decoration', 'none'); button.attr('title', hide_text); }); }); 
================================================ FILE: doc/_static/css/custom.css ================================================ /**** Overriding theme defaults ****/ html[data-theme=light]{ --pst-color-primary: rgb(52, 54, 99); --pst-color-secondary: rgb(107, 161, 174); --pst-color-link: rgb(74, 105, 145); --pst-color-inline-code: rgb(96, 141, 130); } :root { --pst-font-size-h1: 38px; --pst-font-size-h2: 32px; --pst-font-size-h3: 27px; --pst-font-size-h4: 22px; --pst-font-size-h5: 18px; --pst-font-size-h6: 15px; } /* Syntax highlighting */ /* string literals */ html[data-theme=light] .highlight .s2 { color: rgb(74, 105, 145); font-weight: normal; } /* number literals */ html[data-theme=light] .highlight .mi { color: rgb(136, 97, 153); font-weight: normal; } html[data-theme=light] .highlight .mf { color: rgb(136, 97, 153); font-weight: normal; } /* operators */ html[data-theme=light] .highlight .o { color: rgb(219, 164, 117); font-weight: bold; } /* builtins */ html[data-theme=light] .highlight .kc { color: rgb(107, 161, 174); font-weight: bold; } /* Use full page width without sidebars */ .bd-content { max-width: 100%; flex-grow: 1; } /* Function signature customization */ dt { font-weight: 500; color: rgb(52, 54, 99); } span.default_value { color: rgb(124, 141, 138); } /* highlight over function signature after link */ dt:target, span.highlighted { background-color: #fdebba; } /* *********************************************************************** */ /* --- Badges for categorizing release notes --- */ .label, .badge { display: inline-block; padding: 2px 4px; font-size: 11.844px; /* font-weight: bold; */ line-height: 13px; color: #ffffff; vertical-align: baseline; white-space: nowrap; /* text-shadow: 0 -1px 0 rgba(0, 0, 0, 0.25); */ background-color: #999999; } .badge { padding-left: 9px; padding-right: 9px; -webkit-border-radius: 9px; -moz-border-radius: 9px; border-radius: 9px; opacity: 70%; } .badge-api { background-color: #c44e52; } .badge-defaults { 
background-color: #dd8452; } .badge-docs { background-color: #8172b3; } .badge-feature { background-color: #55a868; } .badge-enhancement { background-color: #4c72b0; } .badge-fix { background-color: #ccb974; } .badge-build { background-color: #937860; } ================================================ FILE: doc/_templates/autosummary/base.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. auto{{ objtype }}:: {{ objname }} ================================================ FILE: doc/_templates/autosummary/class.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} {% block methods %} .. automethod:: __init__ {% if methods %} .. rubric:: Methods .. autosummary:: :toctree: ./ {% for item in methods %} ~{{ name }}.{{ item }} {%- endfor %} {% endif %} {% endblock %} {% block attributes %} {% if attributes %} .. rubric:: Attributes .. autosummary:: {% for item in attributes %} ~{{ name }}.{{ item }} {%- endfor %} {% endif %} {% endblock %} ================================================ FILE: doc/_templates/autosummary/object.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} ================================================ FILE: doc/_templates/autosummary/plot.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} {% block methods %} Methods ~~~~~~~ .. rubric:: Specification methods .. autosummary:: :toctree: ./ :nosignatures: ~Plot.add ~Plot.scale .. rubric:: Subplot methods .. autosummary:: :toctree: ./ :nosignatures: ~Plot.facet ~Plot.pair .. rubric:: Customization methods .. autosummary:: :toctree: ./ :nosignatures: ~Plot.layout ~Plot.label ~Plot.limit ~Plot.share ~Plot.theme .. rubric:: Integration methods .. 
autosummary:: :toctree: ./ :nosignatures: ~Plot.on .. rubric:: Output methods .. autosummary:: :toctree: ./ :nosignatures: ~Plot.plot ~Plot.save ~Plot.show {% endblock %} .. _plot_config: Configuration ~~~~~~~~~~~~~ The :class:`Plot` object's default behavior can be configured through its :attr:`Plot.config` attribute. Notice that this is a property of the class, not a method on an instance. .. include:: ../docstrings/objects.Plot.config.rst ================================================ FILE: doc/_templates/autosummary/scale.rst ================================================ {{ fullname | escape | underline}} .. currentmodule:: {{ module }} .. autoclass:: {{ objname }} .. automethod:: tick .. automethod:: label ================================================ FILE: doc/_templates/layout.html ================================================ {% extends "!layout.html" %} {%- block footer %} {%- endblock %} ================================================ FILE: doc/_templates/version.html ================================================ ================================================ FILE: doc/_tutorial/Makefile ================================================ rst_files := $(patsubst %.ipynb,../tutorial/%.rst,$(wildcard *.ipynb)) export MPLBACKEND := module://matplotlib_inline.backend_inline tutorial: ${rst_files} ../tutorial/%.rst: %.ipynb ../tools/nb_to_doc.py $*.ipynb ../tutorial clean: rm -rf ../tutorial ================================================ FILE: doc/_tutorial/aesthetics.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _aesthetics_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Controlling figure aesthetics\n", "=============================\n" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Drawing attractive figures is important. 
When making figures for yourself, as you explore a dataset, it's nice to have plots that are pleasant to look at. Visualizations are also central to communicating quantitative insights to an audience, and in that setting it's even more necessary to have figures that catch the attention and draw a viewer in.\n", "\n", "Matplotlib is highly customizable, but it can be hard to know what settings to tweak to achieve an attractive plot. Seaborn comes with a number of customized themes and a high-level interface for controlling the look of matplotlib figures." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "np.random.seed(sum(map(ord, \"aesthetics\")))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Let's define a simple function to plot some offset sine waves, which will help us see the different stylistic parameters we can tweak." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def sinplot(n=10, flip=1):\n", " x = np.linspace(0, 14, 100)\n", " for i in range(1, n + 1):\n", " plt.plot(x, np.sin(x + i * .5) * (n + 2 - i) * flip)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This is what the plot looks like with matplotlib defaults:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To switch to seaborn defaults, simply call the :func:`set_theme` function." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_theme()\n", "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "(Note that in versions of seaborn prior to 0.8, :func:`set_theme` was called on import. On later versions, it must be explicitly invoked).\n", "\n", "Seaborn splits matplotlib parameters into two independent groups. The first group sets the aesthetic style of the plot, and the second scales various elements of the figure so that it can be easily incorporated into different contexts.\n", "\n", "The interface for manipulating these parameters consists of two pairs of functions. To control the style, use the :func:`axes_style` and :func:`set_style` functions. To scale the plot, use the :func:`plotting_context` and :func:`set_context` functions. In both cases, the first function returns a dictionary of parameters and the second sets the matplotlib defaults.\n", "\n", ".. _axes_style:\n", "\n", "Seaborn figure styles\n", "---------------------\n", "\n", "There are five preset seaborn themes: ``darkgrid``, ``whitegrid``, ``dark``, ``white``, and ``ticks``. They are each suited to different applications and personal preferences. The default theme is ``darkgrid``. As mentioned above, the grid helps the plot serve as a lookup table for quantitative information, and the white-on-grey helps to keep the grid from competing with lines that represent data. The ``whitegrid`` theme is similar, but it is better suited to plots with heavy data elements:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"whitegrid\")\n", "data = np.random.normal(size=(20, 6)) + np.arange(6) / 2\n", "sns.boxplot(data=data);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "For many plots, (especially for settings like talks, where you primarily want to use figures to provide impressions of patterns in the data), the grid is less necessary."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"dark\")\n", "sinplot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"white\")\n", "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Sometimes you might want to give a little extra structure to the plots, which is where ticks come in handy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"ticks\")\n", "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _remove_spines:\n", "\n", "Removing axes spines\n", "--------------------\n", "\n", "Both the ``white`` and ``ticks`` styles can benefit from removing the top and right axes spines, which are not needed. The seaborn function :func:`despine` can be called to remove them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sinplot()\n", "sns.despine()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Some plots benefit from offsetting the spines away from the data, which can also be done when calling :func:`despine`. When the ticks don't cover the whole range of the axis, the ``trim`` parameter will limit the range of the surviving spines." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f, ax = plt.subplots()\n", "sns.violinplot(data=data)\n", "sns.despine(offset=10, trim=True);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can also control which spines are removed with additional arguments to :func:`despine`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"whitegrid\")\n", "sns.boxplot(data=data, palette=\"deep\")\n", "sns.despine(left=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Temporarily setting figure style\n", "--------------------------------\n", "\n", "Although it's easy to switch back and forth, you can also use the :func:`axes_style` function in a ``with`` statement to temporarily set plot parameters. This also allows you to make figures with differently-styled axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f = plt.figure(figsize=(6, 6))\n", "gs = f.add_gridspec(2, 2)\n", "\n", "with sns.axes_style(\"darkgrid\"):\n", " ax = f.add_subplot(gs[0, 0])\n", " sinplot(6)\n", " \n", "with sns.axes_style(\"white\"):\n", " ax = f.add_subplot(gs[0, 1])\n", " sinplot(6)\n", "\n", "with sns.axes_style(\"ticks\"):\n", " ax = f.add_subplot(gs[1, 0])\n", " sinplot(6)\n", "\n", "with sns.axes_style(\"whitegrid\"):\n", " ax = f.add_subplot(gs[1, 1])\n", " sinplot(6)\n", " \n", "f.tight_layout()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Overriding elements of the seaborn styles\n", "-----------------------------------------\n", "\n", "If you want to customize the seaborn styles, you can pass a dictionary of parameters to the ``rc`` argument of :func:`axes_style` and :func:`set_style`. Note that you can only override the parameters that are part of the style definition through this method. 
(However, the higher-level :func:`set_theme` function takes a dictionary of any matplotlib parameters).\n", "\n", "If you want to see what parameters are included, you can just call the function with no arguments, which will return the current settings:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.axes_style()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can then set different versions of these parameters:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_style(\"darkgrid\", {\"axes.facecolor\": \".9\"})\n", "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _plotting_context:\n", "\n", "Scaling plot elements\n", "---------------------\n", "\n", "A separate set of parameters control the scale of plot elements, which should let you use the same code to make plots that are suited for use in settings where larger or smaller plots are appropriate.\n", "\n", "First let's reset the default parameters by calling :func:`set_theme`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The four preset contexts, in order of relative size, are ``paper``, ``notebook``, ``talk``, and ``poster``. The ``notebook`` style is the default, and was used in the plots above." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"paper\")\n", "sinplot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"talk\")\n", "sinplot()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"poster\")\n", "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Most of what you now know about the style functions should transfer to the context functions.\n", "\n", "You can call :func:`set_context` with one of these names to set the parameters, and you can override the parameters by providing a dictionary of parameter values.\n", "\n", "You can also independently scale the size of the font elements when changing the context. (This option is also available through the top-level :func:`set` function)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_context(\"notebook\", font_scale=1.5, rc={\"lines.linewidth\": 2.5})\n", "sinplot()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Similarly, you can temporarily control the scale of figures nested under a ``with`` statement.\n", "\n", "Both the style and the context can be quickly configured with the :func:`set` function. This function also sets the default color palette, but that will be covered in more detail in the :ref:`next section <palette_tutorial>` of the tutorial."
] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/axis_grids.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _grid_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Building structured multi-plot grids\n", "====================================\n" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When exploring multi-dimensional data, a useful approach is to draw multiple instances of the same plot on different subsets of your dataset. This technique is sometimes called either \"lattice\" or \"trellis\" plotting, and it is related to the idea of `\"small multiples\" <https://en.wikipedia.org/wiki/Small_multiple>`_. It allows a viewer to quickly extract a large amount of information about a complex dataset. Matplotlib offers good support for making figures with multiple axes; seaborn builds on top of this to directly link the structure of the plot to the structure of your dataset.\n", "\n", "The :doc:`figure-level </tutorial/function_overview>` functions are built on top of the objects discussed in this chapter of the tutorial. In most cases, you will want to work with those functions. They take care of some important bookkeeping that synchronizes the multiple plots in each grid. This chapter explains how the underlying objects work, which may be useful for advanced applications."
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "import matplotlib.pyplot as plt" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "sns.set_theme(style=\"ticks\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import numpy as np\n", "np.random.seed(sum(map(ord, \"axis_grids\")))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _facet_grid:\n", "\n", "Conditional small multiples\n", "---------------------------\n", "\n", "The :class:`FacetGrid` class is useful when you want to visualize the distribution of a variable or the relationship between multiple variables separately within subsets of your dataset. A :class:`FacetGrid` can be drawn with up to three dimensions: ``row``, ``col``, and ``hue``. The first two have obvious correspondence with the resulting array of axes; think of the hue variable as a third dimension along a depth axis, where different levels are plotted with different colors.\n", "\n", "Each of :func:`relplot`, :func:`displot`, :func:`catplot`, and :func:`lmplot` use this object internally, and they return the object when they are finished so that it can be used for further tweaking.\n", "\n", "The class is used by initializing a :class:`FacetGrid` object with a dataframe and the names of the variables that will form the row, column, or hue dimensions of the grid. These variables should be categorical or discrete, and then the data at each level of the variable will be used for a facet along that axis. 
For example, say we wanted to examine differences between lunch and dinner in the ``tips`` dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "g = sns.FacetGrid(tips, col=\"time\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Initializing the grid like this sets up the matplotlib figure and axes, but doesn't draw anything on them.\n", "\n", "The main approach for visualizing data on this grid is with the :meth:`FacetGrid.map` method. Provide it with a plotting function and the name(s) of variable(s) in the dataframe to plot. Let's look at the distribution of tips in each of these subsets, using a histogram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"time\")\n", "g.map(sns.histplot, \"tip\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This function will draw the figure and annotate the axes, hopefully producing a finished plot in one step. To make a relational plot, just pass multiple variable names. You can also provide keyword arguments, which will be passed to the plotting function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"sex\", hue=\"smoker\")\n", "g.map(sns.scatterplot, \"total_bill\", \"tip\", alpha=.7)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "There are several options for controlling the look of the grid that can be passed to the class constructor." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, row=\"smoker\", col=\"time\", margin_titles=True)\n", "g.map(sns.regplot, \"size\", \"total_bill\", color=\".3\", fit_reg=False, x_jitter=.1)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Note that ``margin_titles`` isn't formally supported by the matplotlib API, and may not work well in all cases. In particular, it currently can't be used with a legend that lies outside of the plot.\n", "\n", "The size of the figure is set by providing the height of *each* facet, along with the aspect ratio:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"day\", height=4, aspect=.5)\n", "g.map(sns.barplot, \"sex\", \"total_bill\", order=[\"Male\", \"Female\"])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The default ordering of the facets is derived from the information in the DataFrame. If the variable used to define facets has a categorical type, then the order of the categories is used. Otherwise, the facets will be in the order of appearance of the category levels. It is possible, however, to specify an ordering of any facet dimension with the appropriate ``*_order`` parameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "ordered_days = tips.day.value_counts().index\n", "g = sns.FacetGrid(tips, row=\"day\", row_order=ordered_days,\n", " height=1.7, aspect=4,)\n", "g.map(sns.kdeplot, \"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Any seaborn color palette (i.e., something that can be passed to :func:`color_palette()`) can be provided. 
You can also use a dictionary that maps the names of values in the ``hue`` variable to valid matplotlib colors:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "pal = dict(Lunch=\"seagreen\", Dinner=\".7\")\n", "g = sns.FacetGrid(tips, hue=\"time\", palette=pal, height=5)\n", "g.map(sns.scatterplot, \"total_bill\", \"tip\", s=100, alpha=.5)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If you have many levels of one variable, you can plot it along the columns but \"wrap\" them so that they span multiple rows. When doing this, you cannot use a ``row`` variable." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "attend = sns.load_dataset(\"attention\").query(\"subject <= 12\")\n", "g = sns.FacetGrid(attend, col=\"subject\", col_wrap=4, height=2, ylim=(0, 10))\n", "g.map(sns.pointplot, \"solutions\", \"score\", order=[1, 2, 3], color=\".3\", errorbar=None)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Once you've drawn a plot using :meth:`FacetGrid.map` (which can be called multiple times), you may want to adjust some aspects of the plot. There are also a number of methods on the :class:`FacetGrid` object for manipulating the figure at a higher level of abstraction. The most general is :meth:`FacetGrid.set`, and there are other more specialized methods like :meth:`FacetGrid.set_axis_labels`, which respects the fact that interior facets do not have axis labels. 
For example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "with sns.axes_style(\"white\"):\n", " g = sns.FacetGrid(tips, row=\"sex\", col=\"smoker\", margin_titles=True, height=2.5)\n", "g.map(sns.scatterplot, \"total_bill\", \"tip\", color=\"#334488\")\n", "g.set_axis_labels(\"Total bill (US Dollars)\", \"Tip\")\n", "g.set(xticks=[10, 30, 50], yticks=[2, 6, 10])\n", "g.figure.subplots_adjust(wspace=.02, hspace=.02)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "For even more customization, you can work directly with the underlying matplotlib ``Figure`` and ``Axes`` objects, which are stored as member attributes at ``figure`` and ``axes_dict``, respectively. When making a figure without row or column faceting, you can also use the ``ax`` attribute to directly access the single axes." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, col=\"smoker\", margin_titles=True, height=4)\n", "g.map(plt.scatter, \"total_bill\", \"tip\", color=\"#338844\", edgecolor=\"white\", s=50, lw=1)\n", "for ax in g.axes_dict.values():\n", " ax.axline((0, 0), slope=.2, c=\".2\", ls=\"--\", zorder=0)\n", "g.set(xlim=(0, 60), ylim=(0, 14))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _custom_map_func:\n", "\n", "Using custom functions\n", "----------------------\n", "\n", "You're not limited to existing matplotlib and seaborn functions when using :class:`FacetGrid`. However, to work properly, any function you use must follow a few rules:\n", "\n", "1. It must plot onto the \"currently active\" matplotlib ``Axes``. This will be true of functions in the ``matplotlib.pyplot`` namespace, and you can call :func:`matplotlib.pyplot.gca` to get a reference to the current ``Axes`` if you want to work directly with its methods.\n", "2. It must accept the data that it plots in positional arguments.
Internally, :class:`FacetGrid` will pass a ``Series`` of data for each of the named positional arguments passed to :meth:`FacetGrid.map`.\n", "3. It must be able to accept ``color`` and ``label`` keyword arguments, and, ideally, it will do something useful with them. In most cases, it's easiest to catch a generic dictionary of ``**kwargs`` and pass it along to the underlying plotting function.\n", "\n", "Let's look at a minimal example of a function you can plot with. This function will just take a single vector of data for each facet:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from scipy import stats\n", "def quantile_plot(x, **kwargs):\n", " quantiles, xr = stats.probplot(x, fit=False)\n", " plt.scatter(xr, quantiles, **kwargs)\n", " \n", "g = sns.FacetGrid(tips, col=\"sex\", height=4)\n", "g.map(quantile_plot, \"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If you want to make a bivariate plot, you should write the function so that it accepts the x-axis variable first and the y-axis variable second:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def qqplot(x, y, **kwargs):\n", " _, xr = stats.probplot(x, fit=False)\n", " _, yr = stats.probplot(y, fit=False)\n", " plt.scatter(xr, yr, **kwargs)\n", " \n", "g = sns.FacetGrid(tips, col=\"smoker\", height=4)\n", "g.map(qqplot, \"total_bill\", \"tip\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Because :func:`matplotlib.pyplot.scatter` accepts ``color`` and ``label`` keyword arguments and does the right thing with them, we can add a hue facet without any difficulty:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(tips, hue=\"time\", col=\"sex\", height=4)\n", "g.map(qqplot, \"total_bill\", \"tip\")\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Sometimes, though, you'll want to map
a function that doesn't work the way you expect with the ``color`` and ``label`` keyword arguments. In this case, you'll want to explicitly catch them and handle them in the logic of your custom function. For example, this approach will allow us to map :func:`matplotlib.pyplot.hexbin`, which otherwise does not play well with the :class:`FacetGrid` API:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def hexbin(x, y, color, **kwargs):\n", " cmap = sns.light_palette(color, as_cmap=True)\n", " plt.hexbin(x, y, gridsize=15, cmap=cmap, **kwargs)\n", "\n", "with sns.axes_style(\"dark\"):\n", " g = sns.FacetGrid(tips, hue=\"time\", col=\"time\", height=4)\n", "g.map(hexbin, \"total_bill\", \"tip\", extent=[0, 50, 0, 10]);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _pair_grid:\n", "\n", "Plotting pairwise data relationships\n", "------------------------------------\n", "\n", ":class:`PairGrid` also allows you to quickly draw a grid of small subplots using the same plot type to visualize data in each. In a :class:`PairGrid`, each row and column is assigned to a different variable, so the resulting plot shows each pairwise relationship in the dataset. This style of plot is sometimes called a \"scatterplot matrix\", as this is the most common way to show each relationship, but :class:`PairGrid` is not limited to scatterplots.\n", "\n", "It's important to understand the differences between a :class:`FacetGrid` and a :class:`PairGrid`. In the former, each facet shows the same relationship conditioned on different levels of other variables. In the latter, each plot shows a different relationship (although the upper and lower triangles will have mirrored plots). Using :class:`PairGrid` can give you a very quick, very high-level summary of interesting relationships in your dataset.\n", "\n", "The basic usage of the class is very similar to :class:`FacetGrid`.
First you initialize the grid, then you pass a plotting function to a ``map`` method and it will be called on each subplot. There is also a companion function, :func:`pairplot` that trades off some flexibility for faster plotting.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "iris = sns.load_dataset(\"iris\")\n", "g = sns.PairGrid(iris)\n", "g.map(sns.scatterplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's possible to plot a different function on the diagonal to show the univariate distribution of the variable in each column. Note that the axis ticks won't correspond to the count or density axis of this plot, though." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(iris)\n", "g.map_diag(sns.histplot)\n", "g.map_offdiag(sns.scatterplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A very common way to use this plot colors the observations by a separate categorical variable. For example, the iris dataset has four measurements for each of three different species of iris flowers so you can see how they differ." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(iris, hue=\"species\")\n", "g.map_diag(sns.histplot)\n", "g.map_offdiag(sns.scatterplot)\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default every numeric column in the dataset is used, but you can focus on particular relationships if you want." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(iris, vars=[\"sepal_length\", \"sepal_width\"], hue=\"species\")\n", "g.map(sns.scatterplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to use a different function in the upper and lower triangles to emphasize different aspects of the relationship."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(iris)\n", "g.map_upper(sns.scatterplot)\n", "g.map_lower(sns.kdeplot)\n", "g.map_diag(sns.kdeplot, lw=3, legend=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The square grid with identity relationships on the diagonal is actually just a special case, and you can plot with different variables in the rows and columns." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(tips, y_vars=[\"tip\"], x_vars=[\"total_bill\", \"size\"], height=4)\n", "g.map(sns.regplot, color=\".3\")\n", "g.set(ylim=(-1, 11), yticks=[0, 5, 10])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Of course, the aesthetic attributes are configurable. For instance, you can use a different palette (say, to show an ordering of the ``hue`` variable) and pass keyword arguments into the plotting functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(tips, hue=\"size\", palette=\"GnBu_d\")\n", "g.map(plt.scatter, s=50, edgecolor=\"white\")\n", "g.add_legend()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ":class:`PairGrid` is flexible, but to take a quick look at a dataset, it can be easier to use :func:`pairplot`. This function uses scatterplots and histograms by default, although a few other kinds will be added (currently, you can also plot regression plots on the off-diagonals and KDEs on the diagonal)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(iris, hue=\"species\", height=2.5)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can also control the aesthetics of the plot with keyword arguments, and it returns the :class:`PairGrid` instance for further tweaking." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.pairplot(iris, hue=\"species\", palette=\"Set2\", diag_kind=\"kde\", height=2.5)" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/categorical.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _categorical_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Visualizing categorical data\n", "============================\n", " \n", "In the :ref:`relational plot tutorial <relational_tutorial>` we saw how to use different visual representations to show the relationship between multiple variables in a dataset. In the examples, we focused on cases where the main relationship was between two numerical variables. If one of the main variables is \"categorical\" (divided into discrete groups) it may be helpful to use a more specialized approach to visualization.\n", "\n", "In seaborn, there are several different ways to visualize a relationship involving categorical data. Similar to the relationship between :func:`relplot` and either :func:`scatterplot` or :func:`lineplot`, there are two ways to make these plots. There are a number of axes-level functions for plotting categorical data in different ways and a figure-level interface, :func:`catplot`, that gives unified higher-level access to them.\n", "\n", "It's helpful to think of the different categorical plot kinds as belonging to three different families, which we'll discuss in detail below.
They are:\n", "\n", "Categorical scatterplots:\n", "\n", "- :func:`stripplot` (with ``kind=\"strip\"``; the default)\n", "- :func:`swarmplot` (with ``kind=\"swarm\"``)\n", "\n", "Categorical distribution plots:\n", "\n", "- :func:`boxplot` (with ``kind=\"box\"``)\n", "- :func:`violinplot` (with ``kind=\"violin\"``)\n", "- :func:`boxenplot` (with ``kind=\"boxen\"``)\n", "\n", "Categorical estimate plots:\n", "\n", "- :func:`pointplot` (with ``kind=\"point\"``)\n", "- :func:`barplot` (with ``kind=\"bar\"``)\n", "- :func:`countplot` (with ``kind=\"count\"``)\n", "\n", "These families represent the data using different levels of granularity. When deciding which to use, you'll have to think about the question that you want to answer. The unified API makes it easy to switch between different kinds and see your data from several perspectives.\n", "\n", "In this tutorial, we'll mostly focus on the figure-level interface, :func:`catplot`. Remember that this function is a higher-level interface to each of the functions above, so we'll reference them when we show each kind of plot, keeping the more verbose kind-specific API documentation at hand." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "sns.set_theme(style=\"ticks\", color_codes=True)\n", "np.random.seed(sum(map(ord, \"categorical\")))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Categorical scatterplots\n", "------------------------\n", "\n", "The default representation of the data in :func:`catplot` uses a scatterplot. There are actually two different categorical scatter plots in seaborn. They take different approaches to resolving the main challenge in representing categorical data with a scatter plot, which is that all of the points belonging to one category would fall on the same position along the axis corresponding to the categorical variable.
The approach used by :func:`stripplot`, which is the default \"kind\" in :func:`catplot`, is to adjust the positions of points on the categorical axis with a small amount of random \"jitter\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.catplot(data=tips, x=\"day\", y=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``jitter`` parameter controls the magnitude of jitter or disables it altogether:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"day\", y=\"total_bill\", jitter=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The second approach adjusts the points along the categorical axis using an algorithm that prevents them from overlapping. It can give a better representation of the distribution of observations, although it only works well for relatively small datasets. This kind of plot is sometimes called a \"beeswarm\" and is drawn in seaborn by :func:`swarmplot`, which is activated by setting ``kind=\"swarm\"`` in :func:`catplot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"day\", y=\"total_bill\", kind=\"swarm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Similar to the relational plots, it's possible to add another dimension to a categorical plot by using a ``hue`` semantic. (The categorical plots do not currently support ``size`` or ``style`` semantics). Each different categorical plotting function handles the ``hue`` semantic differently.
For the scatter plots, it is only necessary to change the color of the points:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"day\", y=\"total_bill\", hue=\"sex\", kind=\"swarm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Unlike with numerical data, it is not always obvious how to order the levels of the categorical variable along its axis. In general, the seaborn categorical plotting functions try to infer the order of categories from the data. If your data have a pandas ``Categorical`` datatype, then the default order of the categories can be set there. If the variable passed to the categorical axis looks numerical, the levels will be sorted. But, by default, the data are still treated as categorical and drawn at ordinal positions on the categorical axes (specifically, at 0, 1, ...) even when numbers are used to label them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips.query(\"size != 3\"), x=\"size\", y=\"total_bill\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As of v0.13.0, all categorical plotting functions have a ``native_scale`` parameter, which can be set to ``True`` when you want to use numeric or datetime data for categorical grouping without changing the underlying data properties: " ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips.query(\"size != 3\"), x=\"size\", y=\"total_bill\", native_scale=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The other option for choosing a default ordering is to take the levels of the category as they appear in the dataset. The ordering can also be controlled on a plot-specific basis using the ``order`` parameter.
This can be important when drawing multiple categorical plots in the same figure, which we'll see more of below:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"smoker\", y=\"tip\", order=[\"No\", \"Yes\"])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "We've referred to the idea of a \"categorical axis\". In these examples, that's always corresponded to the horizontal axis. But it's often helpful to put the categorical variable on the vertical axis (particularly when the category names are relatively long or there are many categories). To do this, swap the assignment of variables to axes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"total_bill\", y=\"day\", hue=\"time\", kind=\"swarm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Comparing distributions\n", "-----------------------\n", "\n", "As the size of the dataset grows, categorical scatter plots become limited in the information they can provide about the distribution of values within each category. When this happens, there are several approaches for summarizing the distributional information in ways that facilitate easy comparisons across the category levels.\n", "\n", "Boxplots\n", "^^^^^^^^\n", "\n", "The first is the familiar :func:`boxplot`. This kind of plot shows the three quartile values of the distribution along with extreme values. The \"whiskers\" extend to points that lie within 1.5 IQRs of the lower and upper quartile, and then observations that fall outside this range are displayed independently. This means that each value in the boxplot corresponds to an actual observation in the data."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"day\", y=\"total_bill\", kind=\"box\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When adding a ``hue`` semantic, the box for each level of the semantic variable is made narrower and shifted along the categorical axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, x=\"day\", y=\"total_bill\", hue=\"smoker\", kind=\"box\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This behavior is called \"dodging\", and it is controlled by the ``dodge`` parameter. By default (as of v0.13.0), elements dodge only if they would otherwise overlap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips[\"weekend\"] = tips[\"day\"].isin([\"Sat\", \"Sun\"])\n", "sns.catplot(data=tips, x=\"day\", y=\"total_bill\", hue=\"weekend\", kind=\"box\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A related function, :func:`boxenplot`, draws a plot that is similar to a box plot but optimized for showing more information about the shape of the distribution.
It is best suited for larger datasets:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "diamonds = sns.load_dataset(\"diamonds\")\n", "sns.catplot(\n", " data=diamonds.sort_values(\"color\"),\n", " x=\"color\", y=\"price\", kind=\"boxen\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Violinplots\n", "^^^^^^^^^^^\n", "\n", "A different approach is a :func:`violinplot`, which combines a boxplot with the kernel density estimation procedure described in the :ref:`distributions <distribution_tutorial>` tutorial:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\", kind=\"violin\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This approach uses the kernel density estimate to provide a richer description of the distribution of values. Additionally, the quartile and whisker values from the boxplot are shown inside the violin.
The downside is that, because the violinplot uses a KDE, there are some other parameters that may need tweaking, adding some complexity relative to the straightforward boxplot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=tips, x=\"total_bill\", y=\"day\", hue=\"sex\",\n", " kind=\"violin\", bw_adjust=.5, cut=0,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to \"split\" the violins, which can allow for a more efficient use of space:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=tips, x=\"day\", y=\"total_bill\", hue=\"sex\",\n", " kind=\"violin\", split=True,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Finally, there are several options for the plot that is drawn on the interior of the violins, including ways to show each individual observation instead of the summary boxplot values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=tips, x=\"day\", y=\"total_bill\", hue=\"sex\",\n", " kind=\"violin\", inner=\"stick\", split=True, palette=\"pastel\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It can also be useful to combine :func:`swarmplot` or :func:`stripplot` with a box plot or violin plot to show each observation along with a summary of the distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.catplot(data=tips, x=\"day\", y=\"total_bill\", kind=\"violin\", inner=None)\n", "sns.swarmplot(data=tips, x=\"day\", y=\"total_bill\", color=\"k\", size=3, ax=g.ax)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Estimating central tendency\n", "---------------------------\n", "\n", "For other applications, rather than showing the distribution within each category, you might want to show an estimate 
of the central tendency of the values. Seaborn has two main ways to show this information. Importantly, the basic API for these functions is identical to that for the ones discussed above.\n", "\n", "Bar plots\n", "^^^^^^^^^\n", "\n", "A familiar style of plot that accomplishes this goal is a bar plot. In seaborn, the :func:`barplot` function operates on a full dataset and applies a function to obtain the estimate (taking the mean by default). When there are multiple observations in each category, it also uses bootstrapping to compute a confidence interval around the estimate, which is plotted using error bars:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "titanic = sns.load_dataset(\"titanic\")\n", "sns.catplot(data=titanic, x=\"sex\", y=\"survived\", hue=\"class\", kind=\"bar\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "The default error bars show 95% confidence intervals, but (starting in v0.12), it is possible to select from a number of other representations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=titanic, x=\"age\", y=\"deck\", errorbar=(\"pi\", 95), kind=\"bar\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A special case for the bar plot is when you want to show the number of observations in each category rather than computing a statistic for a second variable. This is similar to a histogram over a categorical, rather than quantitative, variable. 
In seaborn, it's easy to do so with the :func:`countplot` function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=titanic, x=\"deck\", kind=\"count\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Both :func:`barplot` and :func:`countplot` can be invoked with all of the options discussed above, along with others that are demonstrated in the detailed documentation for each function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=titanic, y=\"deck\", hue=\"class\", kind=\"count\",\n", " palette=\"pastel\", edgecolor=\".6\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Point plots\n", "^^^^^^^^^^^\n", "\n", "An alternative style for visualizing the same information is offered by the :func:`pointplot` function. This function also encodes the value of the estimate with height on the other axis, but rather than showing a full bar, it plots the point estimate and confidence interval. Additionally, :func:`pointplot` connects points from the same ``hue`` category. 
This makes it easy to see how the main relationship is changing as a function of the hue semantic, because your eyes are quite good at picking up on differences of slopes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=titanic, x=\"sex\", y=\"survived\", hue=\"class\", kind=\"point\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "While the categorical functions lack the ``style`` semantic of the relational functions, it can still be a good idea to vary the marker and/or linestyle along with the hue to make figures that are maximally accessible and reproduce well in black and white:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=titanic, x=\"class\", y=\"survived\", hue=\"sex\",\n", " palette={\"male\": \"g\", \"female\": \"m\"},\n", " markers=[\"^\", \"o\"], linestyles=[\"-\", \"--\"],\n", " kind=\"point\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Showing additional dimensions\n", "-----------------------------\n", "\n", "Just like :func:`relplot`, the fact that :func:`catplot` is built on a :class:`FacetGrid` means that it is easy to add faceting variables to visualize higher-dimensional relationships:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(\n", " data=tips, x=\"day\", y=\"total_bill\", hue=\"smoker\",\n", " kind=\"swarm\", col=\"time\", aspect=.7,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "For further customization of the plot, you can use the methods on the :class:`FacetGrid` object that it returns:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.catplot(\n", " data=titanic,\n", " x=\"fare\", y=\"embark_town\", row=\"class\",\n", " kind=\"box\", orient=\"h\",\n", " sharex=False, margin_titles=True,\n", " height=1.5, aspect=4,\n", ")\n", 
"g.set(xlabel=\"Fare\", ylabel=\"\")\n", "g.set_titles(row_template=\"{row_name} class\")\n", "for ax in g.axes.flat:\n", " ax.xaxis.set_major_formatter('${x:.0f}')" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/color_palettes.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ".. _palette_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Choosing color palettes\n", "=======================\n", "\n", "Seaborn makes it easy to use colors that are well-suited to the characteristics of your data and your visualization goals. This chapter discusses both the general principles that should guide your choices and the tools in seaborn that help you quickly find the best solution for a given application." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import matplotlib as mpl\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "sns.set_theme(style=\"white\", rc={\"xtick.major.pad\": 1, \"ytick.major.pad\": 1})" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "np.random.seed(sum(map(ord, \"palettes\")))" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "# Add colormap display methods to matplotlib colormaps.\n", "# These are forthcoming in matplotlib 3.4, but, the matplotlib display\n", "# method includes the colormap name, which is redundant.\n", "def _repr_png_(self):\n", " \"\"\"Generate a PNG representation of the Colormap.\"\"\"\n", " import io\n", " from PIL import Image\n", " import numpy as np\n", " IMAGE_SIZE = (400, 50)\n", " X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1))\n", " pixels = self(X, bytes=True)\n", " png_bytes = io.BytesIO()\n", " Image.fromarray(pixels).save(png_bytes, format='png')\n", " return png_bytes.getvalue()\n", " \n", "def _repr_html_(self):\n", " \"\"\"Generate an HTML representation of the Colormap.\"\"\"\n", " import base64\n", " png_bytes = self._repr_png_()\n", " png_base64 = base64.b64encode(png_bytes).decode('ascii')\n", " return ('')\n", " \n", "import matplotlib as mpl\n", "mpl.colors.Colormap._repr_png_ = _repr_png_\n", "mpl.colors.Colormap._repr_html_ = _repr_html_" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "General principles for using color in plots\n", "-------------------------------------------\n", "\n", "Components of color\n", "~~~~~~~~~~~~~~~~~~~\n", "\n", "Because of the way our eyes work, a particular color can be defined using three components. 
We usually program colors in a computer by specifying their RGB values, which set the intensity of the red, green, and blue channels in a display. But for analyzing the perceptual attributes of a color, it's better to think in terms of *hue*, *saturation*, and *luminance* channels.\n", "\n", "Hue is the component that distinguishes \"different colors\" in a non-technical sense. It's the property of color that leads to first-order names like \"red\" and \"blue\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "sns.husl_palette(8, s=.7)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Saturation (or chroma) is the *colorfulness*. Two colors with different hues will look more distinct when they have more saturation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "c = sns.color_palette(\"muted\")[0]\n", "sns.blend_palette([sns.desaturate(c, 0), c], 8)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "And lightness corresponds to how much light is emitted (or reflected, for printed colors), ranging from black to white:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "sns.blend_palette([\".1\", c, \".95\"], 8)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Vary hue to distinguish categories\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "When you want to represent multiple categories in a plot, you typically should vary the color of the elements. Consider this simple example: in which of these two plots is it easier to count the number of triangular points?" 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "n = 45\n", "rng = np.random.default_rng(200)\n", "x = rng.uniform(0, 1, n * 2)\n", "y = rng.uniform(0, 1, n * 2)\n", "a = np.concatenate([np.zeros(n * 2 - 10), np.ones(10)])\n", "\n", "f, axs = plt.subplots(1, 2, figsize=(7, 3.5), sharey=True, sharex=True)\n", "\n", "sns.scatterplot(\n", " x=x[::2], y=y[::2], style=a[::2], size=a[::2], legend=False,\n", " markers=[\"o\", (3, 1, 1)], sizes=[70, 140], ax=axs[0],\n", ")\n", "\n", "sns.scatterplot(\n", " x=x[1::2], y=y[1::2], style=a[1::2], size=a[1::2], hue=a[1::2], legend=False,\n", " markers=[\"o\", (3, 1, 1)], sizes=[70, 140], ax=axs[1],\n", ")\n", "\n", "f.tight_layout(w_pad=2)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In the plot on the right, the orange triangles \"pop out\", making it easy to distinguish them from the circles. This pop-out effect happens because our visual system prioritizes color differences.\n", "\n", "The blue and orange colors differ mostly in terms of their hue. Hue is useful for representing categories: most people can distinguish a moderate number of hues relatively easily, and points that have different hues but similar brightness or intensity seem equally important. It also makes plots easier to talk about. 
Consider this example:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "b = np.tile(np.arange(10), n // 5)\n", "\n", "f, axs = plt.subplots(1, 2, figsize=(7, 3.5), sharey=True, sharex=True)\n", "\n", "sns.scatterplot(\n", " x=x[::2], y=y[::2], hue=b[::2],\n", " legend=False, palette=\"muted\", s=70, ax=axs[0],\n", ")\n", "\n", "sns.scatterplot(\n", " x=x[1::2], y=y[1::2], hue=b[1::2],\n", " legend=False, palette=\"blend:.75,C0\", s=70, ax=axs[1],\n", ")\n", "\n", "f.tight_layout(w_pad=2)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Most people would be able to quickly ascertain that there are five distinct categories in the plot on the left and, if asked to characterize the \"blue\" points, would be able to do so.\n", "\n", "With the plot on the right, where the points are all blue but vary in their luminance and saturation, it's harder to say how many unique categories are present. And how would we talk about a particular category? \"The fairly-but-not-too-blue points?\" What's more, the gray dots seem to fade into the background, de-emphasizing them relative to the more intense blue dots. If the categories are equally important, this is a poor representation.\n", "\n", "So as a general rule, use hue variation to represent categories. With that said, here are a few notes of caution. If you have more than a handful of colors in your plot, it can become difficult to keep in mind what each one means, unless there are pre-existing associations between the categories and the colors used to represent them. This makes your plot harder to interpret: rather than focusing on the data, a viewer will have to continually refer to the legend to make sense of what is shown. So you should strive not to make plots that are too complex. And be mindful that not everyone sees colors the same way. 
Varying both shape (or some other attribute) and color can help people with anomalous color vision understand your plots, and it can keep them (somewhat) interpretable if they are printed to black-and-white." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Vary luminance to represent numbers\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "On the other hand, hue variations are not well suited to representing numeric data. Consider this example, where we need colors to represent the counts in a bivariate histogram. On the left, we use a circular colormap, where gradual changes in the number of observations within each bin correspond to gradual changes in hue. On the right, we use a palette that uses brighter colors to represent bins with larger counts:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "\n", "f, axs = plt.subplots(1, 2, figsize=(7, 4.25), sharey=True, sharex=True)\n", "\n", "sns.histplot(\n", " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " binwidth=(3, .75), cmap=\"hls\", ax=axs[0],\n", " cbar=True, cbar_kws=dict(orientation=\"horizontal\", pad=.1),\n", ")\n", "axs[0].set(xlabel=\"\", ylabel=\"\")\n", "\n", "\n", "sns.histplot(\n", " data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " binwidth=(3, .75), cmap=\"flare_r\", ax=axs[1],\n", " cbar=True, cbar_kws=dict(orientation=\"horizontal\", pad=.1),\n", ")\n", "axs[1].set(xlabel=\"\", ylabel=\"\")\n", "\n", "f.tight_layout(w_pad=3)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "With the hue-based palette, it's quite difficult to ascertain the shape of the bivariate distribution. In contrast, the luminance palette makes it much more clear that there are two prominent peaks.\n", "\n", "Varying luminance helps you see structure in data, and changes in luminance are more intuitively processed as changes in importance. 
But the plot on the right does not use a grayscale colormap. Its colorfulness makes it more interesting, and the subtle hue variation increases the perceptual distance between two values. As a result, small differences are slightly easier to resolve.\n", "\n", "These examples show that color palette choices are about more than aesthetics: the colors you choose can reveal patterns in your data if used effectively or hide them if used poorly. There is not one optimal palette, but there are palettes that are better or worse for particular datasets and visualization approaches.\n", "\n", "And aesthetics do matter: the more that people want to look at your figures, the greater the chance that they will learn something from them. This is true even when you are making plots for yourself. During exploratory data analysis, you may generate many similar figures. Varying the color palettes will add a sense of novelty, which keeps you engaged and prepared to notice interesting features of your data.\n", "\n", "So how can you choose color palettes that both represent your data well and look attractive?" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Tools for choosing color palettes\n", "---------------------------------\n", "\n", "The most important function for working with color palettes is, aptly, :func:`color_palette`. This function provides an interface to most of the possible ways that one can generate color palettes in seaborn. And it's used internally by any function that has a ``palette`` argument.\n", "\n", "The primary argument to :func:`color_palette` is usually a string: either the name of a specific palette or the name of a family and additional arguments to select a specific member. In the latter case, :func:`color_palette` will delegate to a more specific function, such as :func:`cubehelix_palette`. 
It's also possible to pass a list of colors specified any way that matplotlib accepts (an RGB tuple, a hex code, or a name in the X11 table). The return value is an object that wraps a list of RGB tuples with a few useful methods, such as conversion to hex codes and a rich HTML representation.\n", "\n", "Calling :func:`color_palette` with no arguments will return the current default color palette that matplotlib (and most seaborn functions) will use if colors are not otherwise specified. This default palette can be set with the corresponding :func:`set_palette` function, which calls :func:`color_palette` internally and accepts the same arguments.\n", "\n", "To motivate the different options that :func:`color_palette` provides, it will be useful to introduce a classification scheme for color palettes. Broadly, palettes fall into one of three categories:\n", "\n", "- qualitative palettes, good for representing categorical data\n", "- sequential palettes, good for representing numeric data\n", "- diverging palettes, good for representing numeric data with a categorical boundary" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _qualitative_palettes:\n", "\n", "Qualitative color palettes\n", "--------------------------\n", "\n", "Qualitative palettes are well-suited to representing categorical data because most of their variation is in the hue component. The default color palette in seaborn is a qualitative palette with ten distinct hues:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "These colors have the same ordering as the default matplotlib color palette, ``\"tab10\"``, but they are a bit less intense. 
Compare:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"tab10\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Seaborn in fact has six variations of matplotlib's palette, called ``deep``, ``muted``, ``pastel``, ``bright``, ``dark``, and ``colorblind``. These span a range of average luminance and saturation values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "import io\n", "from IPython.display import SVG\n", "f = mpl.figure.Figure(figsize=(6, 6))\n", "\n", "ax_locs = dict(\n", " deep=(.4, .4),\n", " bright=(.8, .8),\n", " muted=(.49, .71),\n", " dark=(.8, .2),\n", " pastel=(.2, .8),\n", " colorblind=(.71, .49),\n", ")\n", "\n", "s = .35\n", "\n", "for pal, (x, y) in ax_locs.items():\n", " ax = f.add_axes([x - s / 2, y - s / 2, s, s])\n", " ax.pie(np.ones(10),\n", " colors=sns.color_palette(pal, 10),\n", " counterclock=False, startangle=180,\n", " wedgeprops=dict(linewidth=1, edgecolor=\"w\"))\n", " f.text(x, y, pal, ha=\"center\", va=\"center\", size=14,\n", " bbox=dict(facecolor=\"white\", alpha=0.85, boxstyle=\"round,pad=0.2\"))\n", "\n", "f.text(.1, .05, \"Saturation\", size=18, ha=\"left\", va=\"center\",\n", " bbox=dict(facecolor=\"white\", edgecolor=\"w\"))\n", "f.text(.05, .1, \"Luminance\", size=18, ha=\"center\", va=\"bottom\", rotation=90,\n", " bbox=dict(facecolor=\"white\", edgecolor=\"w\"))\n", "\n", "ax = f.add_axes([0, 0, 1, 1])\n", "ax.set_axis_off()\n", "ax.arrow(.15, .05, .4, 0, width=.002, head_width=.015, color=\".15\")\n", "ax.arrow(.05, .15, 0, .4, width=.002, head_width=.015, color=\".15\")\n", "ax.set(xlim=(0, 1), ylim=(0, 1))\n", "f.savefig(svg:=io.StringIO(), format=\"svg\")\n", "SVG(svg.getvalue())" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Many people find the moderated hues of the default ``\"deep\"`` palette to be aesthetically pleasing, but they are also less 
distinct. As a result, they may be more difficult to discriminate in some contexts, which is something to keep in mind when making publication graphics. `This comparison `_ can be helpful for estimating how the seaborn color palettes perform when simulating different forms of colorblindness." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Using circular color systems\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "When you have an arbitrary number of categories, the easiest approach to finding unique hues is to draw evenly-spaced colors in a circular color space (one where the hue changes while keeping the brightness and saturation constant). This is what most seaborn functions default to when they need to use more colors than are currently set in the default color cycle.\n", "\n", "The most common way to do this uses the ``hls`` color space, which is a simple transformation of RGB values. We saw this color palette before as a counterexample for how to plot a histogram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"hls\", 8)" ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ "Because of the way the human visual system works, colors that have the same luminance and saturation in terms of their RGB values won't necessarily look equally intense. To remedy this, seaborn provides an interface to the `husl `_ system (since renamed to HSLuv), which achieves less intensity variation as you rotate around the color wheel:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"husl\", 8)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When seaborn needs a categorical palette with more colors than are available in the current default, it will use this approach.\n", "\n", "Using categorical Color Brewer palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Another source of 
visually pleasing categorical palettes comes from the `Color Brewer `_ tool (which also has sequential and diverging palettes, as we'll see below)." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Set2\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Be aware that the qualitative Color Brewer palettes have different lengths, and the default behavior of :func:`color_palette` is to give you the full list:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Paired\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _sequential_palettes:\n", "\n", "Sequential color palettes\n", "-------------------------\n", "\n", "The second major class of color palettes is called \"sequential\". This kind of mapping is appropriate when data range from relatively low or uninteresting values to relatively high or interesting values (or vice versa). As we saw above, the primary dimension of variation in a sequential palette is luminance. Some seaborn functions will default to a sequential palette when you are mapping numeric data. (For historical reasons, both categorical and numeric mappings are specified with the ``hue`` parameter in functions like :func:`relplot` or :func:`displot`, even though numeric mappings use color palettes with relatively little hue variation).\n", "\n", "Perceptually uniform palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Because they are intended to represent numeric values, the best sequential palettes will be *perceptually uniform*, meaning that the relative discriminability of two colors is proportional to the difference between the corresponding data values. Seaborn includes four perceptually uniform sequential colormaps: ``\"rocket\"``, ``\"mako\"``, ``\"flare\"``, and ``\"crest\"``. 
The first two have a very wide luminance range and are well suited for applications such as heatmaps, where colors fill the space they are plotted into:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"rocket\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"mako\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Because the extreme values of these colormaps approach white, they are not well-suited for coloring elements such as lines or points: it will be difficult to discriminate important values against a white or gray background. The \"flare\" and \"crest\" colormaps are a better choice for such plots. They have a more restricted range of luminance variations, which they compensate for with a slightly more pronounced variation in hue. The default direction of the luminance ramp is also reversed, so that smaller values have lighter colors:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"flare\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"crest\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It is also possible to use the perceptually uniform colormaps provided by matplotlib, such as ``\"magma\"`` and ``\"viridis\"``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"magma\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"viridis\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As with the convention in matplotlib, every continuous colormap has a reversed version, which has the suffix ``\"_r\"``:" ] }, { "cell_type": "code", 
"execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"rocket_r\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Discrete vs. continuous mapping\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "One thing to be aware of is that seaborn can generate discrete values from sequential colormaps and, when doing so, it will not use the most extreme values. Compare the discrete version of ``\"rocket\"`` against the continuous version shown above:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"rocket\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Internally, seaborn uses the discrete version for categorical data and the continuous version when in numeric mapping mode. Discrete sequential colormaps can be well-suited for visualizing categorical data with an intrinsic ordering, especially if there is some hue variation." ] }, { "cell_type": "raw", "metadata": { "raw_mimetype": "text/restructuredtext" }, "source": [ ".. _cubehelix_palettes:\n", "\n", "Sequential \"cubehelix\" palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The perceptually uniform colormaps are difficult to programmatically generate, because they are not based on the RGB color space. The `cubehelix `_ system offers an RGB-based compromise: it generates sequential palettes with a linear increase or decrease in brightness and some continuous variation in hue. While not perfectly perceptually uniform, the resulting colormaps have many good properties. 
Importantly, many aspects of the design process are parameterizable.\n", "\n", "Matplotlib has the default cubehelix version built into it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"cubehelix\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The default palette returned by the seaborn :func:`cubehelix_palette` function is a bit different from the matplotlib default in that it does not rotate as far around the hue wheel or cover as wide a range of intensities. It also reverses the luminance ramp:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Other arguments to :func:`cubehelix_palette` control how the palette looks. The two main things you'll change are the ``start`` (a value between 0 and 3) and ``rot``, or number of rotations (an arbitrary value, but usually between -1 and 1):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(start=.5, rot=-.5, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The more you rotate, the more hue variation you will see:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(start=.5, rot=-.75, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can control both how dark and light the endpoints are and their order:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.cubehelix_palette(start=2, rot=0, dark=0, light=.95, reverse=True, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The :func:`color_palette` function accepts a string code, starting with ``\"ch:\"``, for generating an arbitrary cubehelix palette. 
You can pass the names of parameters in the string:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"ch:start=.2,rot=-.3\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "And for compactness, each parameter can be specified with its first letter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"ch:s=-.2,r=.6\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Custom sequential palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "For a simpler interface to custom sequential palettes, you can use :func:`light_palette` or :func:`dark_palette`, which are both seeded with a single color and produce a palette that ramps either from light or dark desaturated values to that color:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.light_palette(\"seagreen\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.dark_palette(\"#69d\", reverse=True, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As with cubehelix palettes, you can also specify light or dark palettes through :func:`color_palette` or anywhere ``palette`` is accepted:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"light:b\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Reverse the colormap by adding ``\"_r\"``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"dark:salmon_r\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Sequential Color Brewer palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The Color Brewer library also has some good options for sequential palettes. 
They include palettes with one primary hue:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Blues\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Along with multi-hue options:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"YlOrBr\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _diverging_palettes:\n", "\n", "Diverging color palettes\n", "------------------------\n", "\n", "The third class of color palettes is called \"diverging\". These are used for data where both large low and high values are interesting and span a midpoint value (often 0) that should be de-emphasized. The rules for choosing good diverging palettes are similar to good sequential palettes, except now there should be two dominant hues in the colormap, one at (or near) each pole. It's also important that the starting values are of similar brightness and saturation.\n", "\n", "Perceptually uniform diverging palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Seaborn includes two perceptually uniform diverging palettes: ``\"vlag\"`` and ``\"icefire\"``. They both use blue and red at their poles, which many intuitively process as \"cold\" and \"hot\":" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"vlag\", as_cmap=True)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"icefire\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Custom diverging palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "You can also use the seaborn function :func:`diverging_palette` to create a custom colormap for diverging data. This function makes diverging palettes using the ``husl`` color system. 
You pass it two hues (in degrees) and, optionally, the lightness and saturation values for the extremes. Using ``husl`` means that the extreme values, and the resulting ramps to the midpoint, while not perfectly perceptually uniform, will be well-balanced:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(220, 20, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This is convenient when you want to stray from the boring confines of cold-hot approaches:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(145, 300, s=60, as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to make a palette where the midpoint is dark rather than light:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.diverging_palette(250, 30, l=65, center=\"dark\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's important to emphasize here that using red and green, while intuitive, `should be avoided `_.\n", "\n", "Other diverging palettes\n", "~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "There are a few other good diverging palettes built into matplotlib, including Color Brewer palettes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"Spectral\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "And the ``coolwarm`` palette, which has less contrast between the middle values and the extremes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.color_palette(\"coolwarm\", as_cmap=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As you can see, there are many options for using color in your visualizations. 
Seaborn tries both to use good defaults and to offer a lot of flexibility.\n", "\n", "This discussion is only the beginning, and there are a number of good resources for learning more about techniques for using color in visualizations. One great example is this `series of blog posts `_ from the NASA Earth Observatory. The matplotlib docs also have a `nice tutorial `_ that illustrates some of the perceptual properties of their colormaps." ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/data_structure.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _data_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Data structures accepted by seaborn\n", "===================================\n", "\n", "As a data visualization library, seaborn requires that you provide it with data. This chapter explains the various ways to accomplish that task. Seaborn supports several different dataset formats, and most functions accept data represented with objects from the `pandas `_ or `numpy `_ libraries as well as built-in Python types like lists and dictionaries. Understanding the usage patterns associated with these different options will help you quickly create useful visualizations for nearly any dataset.\n", "\n", ".. note::\n", " As of current writing (v0.13.0), the full breadth of options covered here are supported by most, but not all, of the functions in seaborn. 
Namely, a few older functions (e.g., :func:`lmplot` and :func:`regplot`) are more limited in what they accept." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Long-form vs. wide-form data\n", "----------------------------\n", "\n", "Most plotting functions in seaborn are oriented towards *vectors* of data. When plotting ``x`` against ``y``, each variable should be a vector. Seaborn accepts data *sets* that have more than one vector organized in some tabular fashion. There is a fundamental distinction between \"long-form\" and \"wide-form\" data tables, and seaborn will treat each differently.\n", "\n", "Long-form data\n", "~~~~~~~~~~~~~~\n", "\n", "A long-form data table has the following characteristics:\n", "\n", "- Each variable is a column\n", "- Each observation is a row" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As a simple example, consider the \"flights\" dataset, which records the number of airline passengers who flew in each month from 1949 to 1960. This dataset has three variables (*year*, *month*, and number of *passengers*):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights = sns.load_dataset(\"flights\")\n", "flights.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "With long-form data, columns in the table are given roles in the plot by explicitly assigning them to one of the variables. 
For example, making a monthly plot of the number of passengers per year looks like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=flights, x=\"year\", y=\"passengers\", hue=\"month\", kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The advantage of long-form data is that it lends itself well to this explicit specification of the plot. It can accommodate datasets of arbitrary complexity, so long as the variables and observations can be clearly defined. But this format takes some getting used to, because it is often not the model of the data that one has in their head.\n", "\n", "Wide-form data\n", "~~~~~~~~~~~~~~\n", "\n", "For simple datasets, it is often more intuitive to think about data the way it might be viewed in a spreadsheet, where the columns and rows contain *levels* of different variables. For example, we can convert the flights dataset into a wide-form organization by \"pivoting\" it so that each column has each month's time series over years:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_wide = flights.pivot(index=\"year\", columns=\"month\", values=\"passengers\")\n", "flights_wide.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Here we have the same three variables, but they are organized differently. The variables in this dataset are linked to the *dimensions* of the table, rather than to named fields. Each observation is defined by both the value at a cell in the table and the coordinates of that cell with respect to the row and column indices." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "With long-form data, we can access variables in the dataset by their name. That is not the case with wide-form data. 
Nevertheless, because there is a clear association between the dimensions of the table and the variable in the dataset, seaborn is able to assign those variables roles in the plot.\n", "\n", ".. note::\n", " Seaborn treats the argument to ``data`` as wide form when neither ``x`` nor ``y`` are assigned." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=flights_wide, kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This plot looks very similar to the one before. Seaborn has assigned the index of the dataframe to ``x``, the values of the dataframe to ``y``, and it has drawn a separate line for each month. There is a notable difference between the two plots, however. When the dataset went through the \"pivot\" operation that converted it from long-form to wide-form, the information about what the values mean was lost. As a result, there is no y axis label. (The lines also have dashes here, because :func:`relplot` has mapped the column variable to both the ``hue`` and ``style`` semantic so that the plot is more accessible. We didn't do that in the long-form case, but we could have by setting ``style=\"month\"``).\n", "\n", "Thus far, we did much less typing while using wide-form data and made nearly the same plot. This seems easier! But a big advantage of long-form data is that, once you have the data in the correct format, you no longer need to think about its *structure*. You can design your plots by thinking only about the variables contained within it. 
For example, to draw lines that represent the monthly time series for each year, simply reassign the variables:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=flights, x=\"month\", y=\"passengers\", hue=\"year\", kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To achieve the same remapping with the wide-form dataset, we would need to transpose the table:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=flights_wide.transpose(), kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "(This example also illustrates another wrinkle, which is that seaborn currently considers the column variable in a wide-form dataset to be categorical regardless of its datatype, whereas, because the long-form variable is numeric, it is assigned a quantitative color palette and legend. This may change in the future).\n", "\n", "The absence of explicit variable assignments also means that each plot type needs to define a fixed mapping between the dimensions of the wide-form data and the roles in the plot. Because this natural mapping may vary across plot types, the results are less predictable when using wide-form data. For example, the :ref:`categorical ` plots assign the *column* dimension of the table to ``x`` and then aggregate across the rows (ignoring the index):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=flights_wide, kind=\"box\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When using pandas to represent wide-form data, you are limited to just a few variables (no more than three). This is because seaborn does not make use of multi-index information, which is how pandas represents additional variables in a tabular format. 
The `xarray `_ project offers labeled N-dimensional array objects, which can be considered a generalization of wide-form data to higher dimensions. At present, seaborn does not directly support objects from ``xarray``, but they can be transformed into a long-form :class:`pandas.DataFrame` using the ``to_pandas`` method and then plotted in seaborn like any other long-form data set.\n", "\n", "In summary, we can think of long-form and wide-form datasets as looking something like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "import matplotlib.pyplot as plt\n", "f = plt.figure(figsize=(7, 5))\n", "\n", "gs = plt.GridSpec(\n", " ncols=6, nrows=2, figure=f,\n", " left=0, right=.35, bottom=0, top=.9,\n", " height_ratios=(1, 20),\n", " wspace=.1, hspace=.01\n", ")\n", "\n", "colors = [c + (.5,) for c in sns.color_palette()]\n", "\n", "f.add_subplot(gs[0, :], facecolor=\".8\")\n", "[\n", " f.add_subplot(gs[1:, i], facecolor=colors[i])\n", " for i in range(gs.ncols)\n", "]\n", "\n", "gs = plt.GridSpec(\n", " ncols=2, nrows=2, figure=f,\n", " left=.4, right=1, bottom=.2, top=.8,\n", " height_ratios=(1, 8), width_ratios=(1, 11),\n", " wspace=.015, hspace=.02\n", ")\n", "\n", "f.add_subplot(gs[0, 1:], facecolor=colors[2])\n", "f.add_subplot(gs[1:, 0], facecolor=colors[1])\n", "f.add_subplot(gs[1, 1], facecolor=colors[0])\n", "\n", "for ax in f.axes:\n", " ax.set(xticks=[], yticks=[])\n", "\n", "f.text(.35 / 2, .91, \"Long-form\", ha=\"center\", va=\"bottom\", size=15)\n", "f.text(.7, .81, \"Wide-form\", ha=\"center\", va=\"bottom\", size=15)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Messy data\n", "~~~~~~~~~~\n", "\n", "Many datasets cannot be clearly interpreted using either long-form or wide-form rules. If datasets that are clearly long-form or wide-form are `\"tidy\" `_, we might say that these more ambiguous datasets are \"messy\". 
In a messy dataset, the variables are neither uniquely defined by the keys nor by the dimensions of the table. This often occurs with *repeated-measures* data, where it is natural to organize a table such that each row corresponds to the *unit* of data collection. Consider this simple dataset from a psychology experiment in which twenty subjects performed a memory task where they studied anagrams while their attention was either divided or focused:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "anagrams = sns.load_dataset(\"anagrams\")\n", "anagrams" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The attention variable is *between-subjects*, but there is also a *within-subjects* variable: the number of possible solutions to the anagrams, which varied from 1 to 3. The dependent measure is a score of memory performance. These two variables (number and score) are jointly encoded across several columns. As a result, the whole dataset is neither clearly long-form nor clearly wide-form.\n", "\n", "How might we tell seaborn to plot the average score as a function of attention and number of solutions? We'd first need to coerce the data into one of our two structures. Let's transform it to a tidy long-form table, such that each variable is a column and each row is an observation. 
We can use the method :meth:`pandas.DataFrame.melt` to accomplish this task:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "anagrams_long = anagrams.melt(id_vars=[\"subidr\", \"attnr\"], var_name=\"solutions\", value_name=\"score\")\n", "anagrams_long.head()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Now we can make the plot that we want:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=anagrams_long, x=\"solutions\", y=\"score\", hue=\"attnr\", kind=\"point\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Further reading and take-home points\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "For a longer discussion about tabular data structures, you could read the `\"Tidy Data\" `_ paper by Hadley Wickham. Note that seaborn uses a slightly different set of concepts than are defined in the paper. While the paper associates tidiness with long-form structure, we have drawn a distinction between \"tidy wide-form\" data, where there is a clear mapping between variables in the dataset and the dimensions of the table, and \"messy data\", where no such mapping exists.\n", "\n", "The long-form structure has clear advantages. It allows you to create figures by explicitly assigning variables in the dataset to roles in the plot, and you can do so with more than three variables. When possible, try to represent your data with a long-form structure when embarking on serious analysis. Most of the examples in the seaborn documentation will use long-form data. But in cases where it is more natural to keep the dataset wide, remember that seaborn can remain useful."
] }, { "cell_type": "raw", "metadata": {}, "source": [ "Options for visualizing long-form data\n", "--------------------------------------\n", "\n", "While long-form data has a precise definition, seaborn is fairly flexible in terms of how it is actually organized across the data structures in memory. The examples in the rest of the documentation will typically use :class:`pandas.DataFrame` objects and reference variables in them by assigning names of their columns to the variables in the plot. But it is also possible to store vectors in a Python dictionary or a class that implements that interface:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_dict = flights.to_dict()\n", "sns.relplot(data=flights_dict, x=\"year\", y=\"passengers\", hue=\"month\", kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Many pandas operations, such as the split-apply-combine operations of a group-by, will produce a dataframe where information has moved from the columns of the input dataframe to the index of the output. So long as the name is retained, you can still reference the data as normal:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_avg = flights.groupby(\"year\").mean(numeric_only=True)\n", "sns.relplot(data=flights_avg, x=\"year\", y=\"passengers\", kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Additionally, it's possible to pass vectors of data directly as arguments to ``x``, ``y``, and other plotting variables. 
If these vectors are pandas objects, the ``name`` attribute will be used to label the plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "year = flights_avg.index\n", "passengers = flights_avg[\"passengers\"]\n", "sns.relplot(x=year, y=passengers, kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Numpy arrays and other objects that implement the Python sequence interface work too, but if they don't have names, the plot will not be as informative without further tweaking:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(x=year.to_numpy(), y=passengers.to_list(), kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Options for visualizing wide-form data\n", "--------------------------------------\n", "\n", "The options for passing wide-form data are even more flexible. As with long-form data, pandas objects are preferable because the name (and, in some cases, index) information can be used. But in essence, any format that can be viewed as a single vector or a collection of vectors can be passed to ``data``, and a valid plot can usually be constructed.\n", "\n", "The example we saw above used a rectangular :class:`pandas.DataFrame`, which can be thought of as a collection of its columns. A dict or list of pandas objects will also work, but we'll lose the axis labels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_wide_list = [col for _, col in flights_wide.items()]\n", "sns.relplot(data=flights_wide_list, kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The vectors in a collection do not need to have the same length. 
If they have an ``index``, it will be used to align them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "two_series = [flights_wide.loc[:1955, \"Jan\"], flights_wide.loc[1952:, \"Aug\"]]\n", "sns.relplot(data=two_series, kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Whereas an ordinal index will be used for numpy arrays or simple Python sequences:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "two_arrays = [s.to_numpy() for s in two_series]\n", "sns.relplot(data=two_arrays, kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "But a dictionary of such vectors will at least use the keys:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "two_arrays_dict = {s.name: s.to_numpy() for s in two_series}\n", "sns.relplot(data=two_arrays_dict, kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Rectangular numpy arrays are treated just like a dataframe without index information, so they are viewed as a collection of column vectors. Note that this is different from how numpy indexing operations work, where a single indexer will access a row. 
But it is consistent with how pandas would turn the array into a dataframe or how matplotlib would plot it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "flights_array = flights_wide.to_numpy()\n", "sns.relplot(data=flights_array, kind=\"line\")" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/distributions.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _distribution_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Visualizing distributions of data\n", "==================================\n", "\n", "An early step in any effort to analyze or model data should be to understand how the variables are distributed. Techniques for distribution visualization can provide quick answers to many important questions. What range do the observations cover? What is their central tendency? Are they heavily skewed in one direction? Is there evidence for bimodality? Are there significant outliers? Do the answers to these questions vary across subsets defined by other variables?\n", "\n", "The :ref:`distributions module ` contains several functions designed to answer questions such as these. The axes-level functions are :func:`histplot`, :func:`kdeplot`, :func:`ecdfplot`, and :func:`rugplot`. 
They are grouped together within the figure-level :func:`displot`, :func:`jointplot`, and :func:`pairplot` functions.\n", "\n", "There are several different approaches to visualizing a distribution, and each has its relative advantages and drawbacks. It is important to understand these factors so that you can choose the best approach for your particular aim." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "import seaborn as sns; sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _tutorial_hist:\n", "\n", "Plotting univariate histograms\n", "------------------------------\n", "\n", "Perhaps the most common approach to visualizing a distribution is the *histogram*. This is the default approach in :func:`displot`, which uses the same underlying code as :func:`histplot`. A histogram is a bar plot where the axis representing the data variable is divided into a set of discrete bins and the count of observations falling within each bin is shown using the height of the corresponding bar:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.displot(penguins, x=\"flipper_length_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This plot immediately affords a few insights about the ``flipper_length_mm`` variable. For instance, we can see that the most common flipper length is about 195 mm, but the distribution appears bimodal, so this one number does not represent the data well.\n", "\n", "Choosing the bin size\n", "^^^^^^^^^^^^^^^^^^^^^\n", "\n", "The size of the bins is an important parameter, and using the wrong bin size can mislead by obscuring important features of the data or by creating apparent features out of random variability. 
By default, :func:`displot`/:func:`histplot` choose a bin size based on the variance of the data and the number of observations. But you should not be over-reliant on such automatic approaches, because they depend on particular assumptions about the structure of your data. It is always advisable to check that your impressions of the distribution are consistent across different bin sizes. To choose the size directly, set the ``binwidth`` parameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", binwidth=3)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In other circumstances, it may make more sense to specify the *number* of bins, rather than their size:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", bins=20)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "One example of a situation where defaults fail is when the variable takes a relatively small number of integer values. In that case, the default bin width may be too small, creating awkward gaps in the distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.displot(tips, x=\"size\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "One approach would be to specify the precise bin breaks by passing an array to ``bins``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(tips, x=\"size\", bins=[1, 2, 3, 4, 5, 6, 7])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This can also be accomplished by setting ``discrete=True``, which chooses bin breaks that represent the unique values in a dataset with bars that are centered on their corresponding value."
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(tips, x=\"size\", discrete=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to visualize the distribution of a categorical variable using the logic of a histogram. Discrete bins are automatically set for categorical variables, but it may also be helpful to \"shrink\" the bars slightly to emphasize the categorical nature of the axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(tips, x=\"day\", shrink=.8)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Conditioning on other variables\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "Once you understand the distribution of a variable, the next step is often to ask whether features of that distribution differ across other variables in the dataset. For example, what accounts for the bimodal distribution of flipper lengths that we saw above? :func:`displot` and :func:`histplot` provide support for conditional subsetting via the ``hue`` semantic. Assigning a variable to ``hue`` will draw a separate histogram for each of its unique values and distinguish them by color:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, the different histograms are \"layered\" on top of each other and, in some cases, they may be difficult to distinguish. 
One option is to change the visual representation of the histogram from a bar plot to a \"step\" plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", element=\"step\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Alternatively, instead of layering each bar, they can be \"stacked\", or moved vertically. In this plot, the outline of the full histogram will match the plot with only a single variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The stacked histogram emphasizes the part-whole relationship between the variables, but it can obscure other features (for example, it is difficult to determine the mode of the Adelie distribution). Another option is to \"dodge\" the bars, which moves them horizontally and reduces their width. This ensures that there are no overlaps and that the bars remain comparable in terms of height. But it only works well when the categorical variable has a small number of levels:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"sex\", multiple=\"dodge\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Because :func:`displot` is a figure-level function and is drawn onto a :class:`FacetGrid`, it is also possible to draw each individual distribution in a separate subplot by assigning the second variable to ``col`` or ``row`` rather than (or in addition to) ``hue``.
This represents the distribution of each subset well, but it makes it more difficult to draw direct comparisons:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", col=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "None of these approaches are perfect, and we will soon see some alternatives to a histogram that are better-suited to the task of comparison.\n", "\n", "Normalized histogram statistics\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "Before we do, another point to note is that, when the subsets have unequal numbers of observations, comparing their distributions in terms of counts may not be ideal. One solution is to *normalize* the counts using the ``stat`` parameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", stat=\"density\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "By default, however, the normalization is applied to the entire distribution, so this simply rescales the height of the bars. By setting ``common_norm=False``, each subset will be normalized independently:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", stat=\"density\", common_norm=False)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Density normalization scales the bars so that their *areas* sum to 1. As a result, the density axis is not directly interpretable. Another option is to normalize the bars so that their *heights* sum to 1.
This makes most sense when the variable is discrete, but it is an option for all histograms:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", stat=\"probability\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _tutorial_kde:\n", "\n", "Kernel density estimation\n", "-------------------------\n", "\n", "A histogram aims to approximate the underlying probability density function that generated the data by binning and counting observations. Kernel density estimation (KDE) presents a different solution to the same problem. Rather than using discrete bins, a KDE plot smooths the observations with a Gaussian kernel, producing a continuous density estimate:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Choosing the smoothing bandwidth\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "Much like with the bin size in the histogram, the ability of the KDE to accurately represent the data depends on the choice of smoothing bandwidth. An over-smoothed estimate might erase meaningful features, but an under-smoothed estimate can obscure the true shape within random noise. The easiest way to check the robustness of the estimate is to adjust the default bandwidth:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"kde\", bw_adjust=.25)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Note how the narrow bandwidth makes the bimodality much more apparent, but the curve is much less smooth. 
In contrast, a larger bandwidth obscures the bimodality almost completely:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"kde\", bw_adjust=2)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Conditioning on other variables\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "As with histograms, if you assign a ``hue`` variable, a separate density estimate will be computed for each level of that variable:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In many cases, the layered KDE is easier to interpret than the layered histogram, so it is often a good choice for the task of comparison. Many of the same options for resolving multiple distributions apply to the KDE as well, however:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\", multiple=\"stack\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Note how the stacked plot filled in the area between each curve by default. It is also possible to fill in the curves for single or layered densities, although the default alpha value (opacity) will be different, so that the individual densities are easier to resolve." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"kde\", fill=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Kernel density estimation pitfalls\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "KDE plots have many advantages. 
Important features of the data are easy to discern (central tendency, bimodality, skew), and they afford easy comparisons between subsets. But there are also situations where KDE poorly represents the underlying data. This is because the logic of KDE assumes that the underlying distribution is smooth and unbounded. One way this assumption can fail is when a variable reflects a quantity that is naturally bounded. If there are observations lying close to the bound (for example, small values of a variable that cannot be negative), the KDE curve may extend to unrealistic values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(tips, x=\"total_bill\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This can be partially avoided with the ``cut`` parameter, which specifies how far the curve should extend beyond the extreme datapoints. But this influences only where the curve is drawn; the density estimate will still smooth over the range where no data can exist, causing it to be artificially low at the extremes of the distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(tips, x=\"total_bill\", kind=\"kde\", cut=0)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The KDE approach also fails for discrete data or when data are naturally continuous but specific values are over-represented. The important thing to keep in mind is that the KDE will *always show you a smooth curve*, even when the data themselves are not smooth. 
For example, consider this distribution of diamond weights:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "diamonds = sns.load_dataset(\"diamonds\")\n", "sns.displot(diamonds, x=\"carat\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "While the KDE suggests that there are peaks around specific values, the histogram reveals a much more jagged distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(diamonds, x=\"carat\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As a compromise, it is possible to combine these two approaches. While in histogram mode, :func:`displot` (as with :func:`histplot`) has the option of including the smoothed KDE curve (note ``kde=True``, not ``kind=\"kde\"``):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(diamonds, x=\"carat\", kde=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _tutorial_ecdf:\n", "\n", "Empirical cumulative distributions\n", "----------------------------------\n", "\n", "A third option for visualizing distributions computes the \"empirical cumulative distribution function\" (ECDF). This plot draws a monotonically-increasing curve through each datapoint such that the height of the curve reflects the proportion of observations with a smaller value:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", kind=\"ecdf\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ECDF plot has two key advantages. Unlike the histogram or KDE, it directly represents each datapoint. That means there is no bin size or smoothing parameter to consider. 
Additionally, because the curve is monotonically increasing, it is well-suited for comparing multiple distributions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"flipper_length_mm\", hue=\"species\", kind=\"ecdf\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The major downside to the ECDF plot is that it represents the shape of the distribution less intuitively than a histogram or density curve. Consider how the bimodality of flipper lengths is immediately apparent in the histogram, but to see it in the ECDF plot, you must look for varying slopes. Nevertheless, with practice, you can learn to answer all of the important questions about a distribution by examining the ECDF, and doing so can be a powerful approach." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Visualizing bivariate distributions\n", "-----------------------------------\n", "\n", "All of the examples so far have considered *univariate* distributions: distributions of a single variable, perhaps conditional on a second variable assigned to ``hue``. Assigning a second variable to ``y``, however, will plot a *bivariate* distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A bivariate histogram bins the data within rectangles that tile the plot and then shows the count of observations within each rectangle with the fill color (analogous to a :func:`heatmap`). Similarly, a bivariate KDE plot smoothes the (x, y) observations with a 2D Gaussian. 
The default representation then shows the *contours* of the 2D density:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Assigning a ``hue`` variable will plot multiple heatmaps or contour sets using different colors. For bivariate histograms, this will only work well if there is minimal overlap between the conditional distributions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The contour approach of the bivariate KDE plot lends itself better to evaluating overlap, although a plot with too many contours can get busy:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Just as with univariate plots, the choice of bin size or smoothing bandwidth will determine how well the plot represents the underlying bivariate distribution. 
The same parameters apply, but they can be tuned for each variable by passing a pair of values:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", binwidth=(2, .5))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To aid interpretation of the heatmap, add a colorbar to show the mapping between counts and color intensity:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", binwidth=(2, .5), cbar=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The meaning of the bivariate density contours is less straightforward. Because the density is not directly interpretable, the contours are drawn at *iso-proportions* of the density, meaning that each curve shows a level set such that some proportion *p* of the density lies below it. The *p* values are evenly spaced, with the lowest level contolled by the ``thresh`` parameter and the number controlled by ``levels``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"kde\", thresh=.2, levels=4)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The ``levels`` parameter also accepts a list of values, for more control:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\", kind=\"kde\", levels=[.01, .05, .1, .8])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The bivariate histogram allows one or both variables to be discrete. 
Plotting one discrete and one continuous variable offers another way to compare conditional univariate distributions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(diamonds, x=\"price\", y=\"clarity\", log_scale=(True, False))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In contrast, plotting two discrete variables is an easy to way show the cross-tabulation of the observations:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(diamonds, x=\"color\", y=\"clarity\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Distribution visualization in other settings\n", "--------------------------------------------\n", "\n", "Several other figure-level plotting functions in seaborn make use of the :func:`histplot` and :func:`kdeplot` functions.\n", "\n", "\n", "Plotting joint and marginal distributions\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "The first is :func:`jointplot`, which augments a bivariate relational or distribution plot with the marginal distributions of the two variables. 
By default, :func:`jointplot` represents the bivariate distribution using :func:`scatterplot` and the marginal distributions using :func:`histplot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Similar to :func:`displot`, setting a different ``kind=\"kde\"`` in :func:`jointplot` will change both the joint and marginal plots the use :func:`kdeplot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(\n", " data=penguins,\n", " x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"species\",\n", " kind=\"kde\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ":func:`jointplot` is a convenient interface to the :class:`JointGrid` class, which offeres more flexibility when used directly:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.JointGrid(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "g.plot_joint(sns.histplot)\n", "g.plot_marginals(sns.boxplot)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A less-obtrusive way to show marginal distributions uses a \"rug\" plot, which adds a small tick on the edge of the plot to represent each individual observation. 
This is built into :func:`displot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(\n", " penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " kind=\"kde\", rug=True\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "And the axes-level :func:`rugplot` function can be used to add rugs on the side of any other kind of plot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", "sns.rugplot(data=penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Plotting many distributions\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "The :func:`pairplot` function offers a similar blend of joint and marginal distributions. Rather than focusing on a single relationship, however, :func:`pairplot` uses a \"small-multiple\" approach to visualize the univariate distribution of all variables in a dataset along with all of their pairwise relationships:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(penguins)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As with :func:`jointplot`/:class:`JointGrid`, using the underlying :class:`PairGrid` directly will afford more flexibility with only a bit more typing:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins)\n", "g.map_upper(sns.histplot)\n", "g.map_lower(sns.kdeplot, fill=True)\n", "g.map_diag(sns.histplot, kde=True)" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", 
"version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/error_bars.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _errorbar_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib as mpl\n", "import matplotlib.pyplot as plt\n", "sns.set_theme(style=\"darkgrid\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "np.random.seed(sum(map(ord, \"errorbars\")))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Statistical estimation and error bars\n", "=====================================\n", "\n", "Data visualization sometimes involves a step of aggregation or estimation, where multiple data points are reduced to a summary statistic such as the mean or median. When showing a summary statistic, it is usually appropriate to add *error bars*, which provide a visual cue about how well the summary represents the underlying data points.\n", "\n", "Several seaborn functions will automatically calculate both summary statistics and the error bars when given a full dataset. This chapter explains how you can control what the error bars show and why you might choose each of the options that seaborn affords.\n", "\n", "The error bars around an estimate of central tendency can show one of two general things: either the range of uncertainty about the estimate or the spread of the underlying data around it. These measures are related: given the same sample size, estimates will be more uncertain when data has a broader spread. 
But uncertainty will decrease as sample sizes grow, whereas spread will not.\n", "\n", "In seaborn, there are two approaches for constructing each kind of error bar. One approach is parametric, using a formula that relies on assumptions about the shape of the distribution. The other approach is nonparametric, using only the data that you provide.\n", "\n", "Your choice is made with the `errorbar` parameter, which exists for each function that does estimation as part of plotting. This parameter accepts the name of the method to use and, optionally, a parameter that controls the size of the interval. The choices can be defined in a 2D taxonomy that depends on what is shown and how it is constructed:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "import io\n", "from IPython.display import SVG\n", "f = mpl.figure.Figure(figsize=(8, 5))\n", "axs = f.subplots(2, 2, sharex=True, sharey=True,)\n", "\n", "plt.setp(axs, xlim=(-1, 1), ylim=(-1, 1), xticks=[], yticks=[])\n", "for ax, color in zip(axs.flat, [\"C0\", \"C2\", \"C3\", \"C1\"]):\n", " ax.set_facecolor(mpl.colors.to_rgba(color, .25))\n", "\n", "kws = dict(x=0, y=.2, ha=\"center\", va=\"center\", size=18)\n", "axs[0, 0].text(s=\"Standard deviation\", **kws)\n", "axs[0, 1].text(s=\"Standard error\", **kws)\n", "axs[1, 0].text(s=\"Percentile interval\", **kws)\n", "axs[1, 1].text(s=\"Confidence interval\", **kws)\n", "\n", "kws = dict(x=0, y=-.2, ha=\"center\", va=\"center\", size=18, font=\"Courier New\")\n", "axs[0, 0].text(s='errorbar=(\"sd\", scale)', **kws)\n", "axs[0, 1].text(s='errorbar=(\"se\", scale)', **kws)\n", "axs[1, 0].text(s='errorbar=(\"pi\", width)', **kws)\n", "axs[1, 1].text(s='errorbar=(\"ci\", width)', **kws)\n", "\n", "kws = dict(size=18)\n", "axs[0, 0].set_title(\"Spread\", **kws)\n", "axs[0, 1].set_title(\"Uncertainty\", **kws)\n", "axs[0, 0].set_ylabel(\"Parametric\", **kws)\n", "axs[1, 0].set_ylabel(\"Nonparametric\", 
**kws)\n", "\n", "f.tight_layout()\n", "f.subplots_adjust(hspace=.05, wspace=.05 * (4 / 6))\n", "f.savefig(svg:=io.StringIO(), format=\"svg\")\n", "SVG(svg.getvalue())" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You will note that the size parameter is defined differently for the parametric and nonparametric approaches. For parametric error bars, it is a scalar factor that is multiplied by the statistic defining the error (standard error or standard deviation). For nonparametric error bars, it is a percentile width. This is explained further for each specific approach below.\n", "\n", "\n", ".. note::\n", " The `errorbar` API described here was introduced in seaborn v0.12. In prior versions, the only options were to show a bootstrap confidence interval or a standard deviation, with the choice controlled by the `ci` parameter (i.e., `ci=` or `ci=\"sd\"`).\n", "\n", "To compare the different parameterizations, we'll use the following helper function:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def plot_errorbars(arg, **kws):\n", " np.random.seed(sum(map(ord, \"error_bars\")))\n", " x = np.random.normal(0, 1, 100)\n", " f, axs = plt.subplots(2, figsize=(7, 2), sharex=True, layout=\"tight\")\n", " sns.pointplot(x=x, errorbar=arg, **kws, capsize=.3, ax=axs[0])\n", " sns.stripplot(x=x, jitter=.3, ax=axs[1])" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Measures of data spread\n", "-----------------------\n", "\n", "Error bars that represent data spread present a compact display of the distribution, using three numbers where :func:`boxplot` would use 5 or more and :func:`violinplot` would use a complicated algorithm.\n", "\n", "Standard deviation error bars\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Standard deviation error bars are the simplest to explain, because the standard deviation is a familiar statistic. It is the average distance from each data point to the sample mean. 
By default, `errorbar=\"sd\"` will draw error bars at +/- 1 sd around the estimate, but the range can be increased by passing a scaling size parameter. Note that, assuming normally-distributed data, ~68% of the data will lie within one standard deviation, ~95% will lie within two, and ~99.7% will lie within three:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars(\"sd\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Percentile interval error bars\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Percentile intervals also represent the range where some amount of the data fall, but they do so by \n", "computing those percentiles directly from your sample. By default, `errorbar=\"pi\"` will show a 95% interval, ranging from the 2.5 to the 97.5 percentiles. You can choose a different range by passing a size parameter, e.g., to show the inter-quartile range:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars((\"pi\", 50))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The standard deviation error bars will always be symmetrical around the estimate. This can be a problem when the data are skewed, especially if there are natural bounds (e.g., if the data represent a quantity that can only be positive). In some cases, standard deviation error bars may extend to \"impossible\" values. The nonparametric approach does not have this problem, because it can account for asymmetrical spread and will never extend beyond the range of the data." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Measures of estimate uncertainty\n", "--------------------------------\n", "\n", "If your data are a random sample from a larger population, then the mean (or other estimate) will be an imperfect measure of the true population average. 
Error bars that show estimate uncertainty try to represent the range of likely values for the true parameter.\n", "\n", "Standard error bars\n", "~~~~~~~~~~~~~~~~~~~\n", "\n", "The standard error statistic is related to the standard deviation: in fact it is just the standard deviation divided by the square root of the sample size. The default, with `errorbar=\"se\"`, draws an interval +/-1 standard error from the mean:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars(\"se\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Confidence interval error bars\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The nonparametric approach to representing uncertainty uses *bootstrapping*: a procedure where the dataset is randomly resampled with replacement a number of times, and the estimate is recalculated from each resample. This procedure creates a distribution of statistics approximating the distribution of values that you could have gotten for your estimate if you had a different sample.\n", "\n", "The confidence interval is constructed by taking a percentile interval of the *bootstrap distribution*. By default `errorbar=\"ci\"` draws a 95% confidence interval:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars(\"ci\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The seaborn terminology is somewhat specific, because a confidence interval in statistics can be parametric or nonparametric. To draw a parametric confidence interval, you scale the standard error, using a formula similar to the one mentioned above. 
For example, an approximate 95% confidence interval can be constructed by taking the mean +/- two standard errors:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars((\"se\", 2))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The nonparametric bootstrap has advantages similar to those of the percentile interval: it will naturally adapt to skewed and bounded data in a way that a standard error interval cannot. It is also more general. While the standard error formula is specific to the mean, error bars can be computed using the bootstrap for any estimator:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars(\"ci\", estimator=\"median\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Bootstrapping involves randomness, and the error bars will appear slightly different each time you run the code that creates them. A few parameters control this. One sets the number of iterations (`n_boot`): with more iterations, the resulting intervals will be more stable. The other sets the `seed` for the random number generator, which will ensure identical results:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars(\"ci\", n_boot=5000, seed=10)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Because of its iterative process, bootstrap intervals can be expensive to compute, especially for large datasets. But because uncertainty decreases with sample size, it may be more informative in that case to use an error bar that represents data spread.\n", "\n", "Custom error bars\n", "~~~~~~~~~~~~~~~~~\n", "\n", "If these recipes are not sufficient, it is also possible to pass a generic function to the `errorbar` parameter. 
This function should take a vector and produce a pair of values representing the minimum and maximum points of the interval:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "plot_errorbars(lambda x: (x.min(), x.max()))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "(In practice, you could show the full range of the data with `errorbar=(\"pi\", 100)` rather than the custom function shown above).\n", "\n", "Note that seaborn functions cannot currently draw error bars from values that have been calculated externally, although matplotlib functions can be used to add such error bars to seaborn plots." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Error bars on regression fits\n", "-----------------------------\n", "\n", "The preceding discussion has focused on error bars shown around parameter estimates for aggregate data. Error bars also arise in seaborn when estimating regression models to visualize relationships. Here, the error bars will be represented by a \"band\" around the regression line:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "x = np.random.normal(0, 1, 50)\n", "y = x * 2 + np.random.normal(0, 2, size=x.size)\n", "sns.regplot(x=x, y=y)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Currently, the error bars on a regression estimate are less flexible, only showing a confidence interval with a size set through `ci=`. This may change in the future." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Are error bars enough?\n", "----------------------\n", "\n", "You should always ask yourself whether it's best to use a plot that displays only a summary statistic and error bar. In many cases, it isn't.\n", "\n", "If you are interested in questions about summaries (such as whether the mean value differs between groups or increases over time), aggregation reduces the complexity of the plot and makes those inferences easier. 
But in doing so, it obscures valuable information about the underlying data points, such as the shape of the distributions and the presence of outliers.\n", "\n", "When analyzing your own data, don't be satisfied with summary statistics. Always look at the underlying distributions too. Sometimes, it can be helpful to combine both perspectives into the same figure. Many seaborn functions can help with this task, especially those discussed in the :doc:`categorical tutorial `." ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/function_overview.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _function_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Overview of seaborn plotting functions\n", "======================================\n", "\n", "Most of your interactions with seaborn will happen through a set of plotting functions. Later chapters in the tutorial will explore the specific features offered by each function. This chapter will introduce, at a high-level, the different kinds of functions that you will encounter." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "from IPython.display import HTML\n", "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Similar functions for similar tasks\n", "-----------------------------------\n", "\n", "The seaborn namespace is flat; all of the functionality is accessible at the top level. But the code itself is hierarchically structured, with modules of functions that achieve similar visualization goals through different means. Most of the docs are structured around these modules: you'll encounter names like \"relational\", \"distributional\", and \"categorical\".\n", "\n", "For example, the :ref:`distributions module ` defines functions that specialize in representing the distribution of datapoints. This includes familiar methods like the histogram:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.histplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Along with similar, but perhaps less familiar, options such as kernel density estimation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.kdeplot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Functions within a module share a lot of underlying code and offer similar features that may not be present in other components of the library (such as ``multiple=\"stack\"`` in the examples above). 
They are designed to facilitate switching between different visual representations as you explore a dataset, because different representations often have complementary strengths and weaknesses." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Figure-level vs. axes-level functions\n", "-------------------------------------\n", "\n", "In addition to the different modules, there is a cross-cutting classification of seaborn functions as \"axes-level\" or \"figure-level\". The examples above are axes-level functions. They plot data onto a single :class:`matplotlib.pyplot.Axes` object, which is the return value of the function.\n", "\n", "In contrast, figure-level functions interface with matplotlib through a seaborn object, usually a :class:`FacetGrid`, that manages the figure. Each module has a single figure-level function, which offers a unitary interface to its various axes-level functions. The organization looks a bit like this:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "from matplotlib.patches import FancyBboxPatch\n", "\n", "f, ax = plt.subplots(figsize=(7, 5))\n", "f.subplots_adjust(0, 0, 1, 1)\n", "ax.set_axis_off()\n", "ax.set(xlim=(0, 1), ylim=(0, 1))\n", "\n", "\n", "modules = \"relational\", \"distributions\", \"categorical\"\n", "\n", "pal = sns.color_palette(\"deep\")\n", "colors = dict(relational=pal[0], distributions=pal[1], categorical=pal[2])\n", "\n", "pal = sns.color_palette(\"dark\")\n", "text_colors = dict(relational=pal[0], distributions=pal[1], categorical=pal[2])\n", "\n", "\n", "functions = dict(\n", " relational=[\"scatterplot\", \"lineplot\"],\n", " distributions=[\"histplot\", \"kdeplot\", \"ecdfplot\", \"rugplot\"],\n", " categorical=[\"stripplot\", \"swarmplot\", \"boxplot\", \"violinplot\", \"pointplot\", \"barplot\"],\n", ")\n", "\n", "pad = .06\n", "\n", "w = .2\n", "h = .15\n", "\n", "xs = np.arange(0, 1, 1 / 3) + pad * 1.05\n", "y = .7\n", "\n", 
"for x, mod in zip(xs, modules):\n", " color = colors[mod] + (.2,)\n", " text_color = text_colors[mod]\n", " box = FancyBboxPatch((x, y), w, h, f\"round,pad={pad}\", color=\"white\")\n", " ax.add_artist(box)\n", " box = FancyBboxPatch((x, y), w, h, f\"round,pad={pad}\", linewidth=1, edgecolor=text_color, facecolor=color)\n", " ax.add_artist(box)\n", " ax.text(x + w / 2, y + h / 2, f\"{mod[:3]}plot\\n({mod})\", ha=\"center\", va=\"center\", size=22, color=text_color)\n", "\n", " for i, func in enumerate(functions[mod]):\n", " x_i = x + w / 2\n", " y_i = y - i * .1 - h / 2 - pad\n", " box = FancyBboxPatch((x_i - w / 2, y_i - pad / 3), w, h / 4, f\"round,pad={pad / 3}\",\n", " color=\"white\")\n", " ax.add_artist(box)\n", " box = FancyBboxPatch((x_i - w / 2, y_i - pad / 3), w, h / 4, f\"round,pad={pad / 3}\",\n", " linewidth=1, edgecolor=text_color, facecolor=color)\n", " ax.add_artist(box)\n", " ax.text(x_i, y_i, func, ha=\"center\", va=\"center\", size=18, color=text_color)\n", "\n", " ax.plot([x_i, x_i], [y, y_i], zorder=-100, color=text_color, lw=1)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "For example, :func:`displot` is the figure-level function for the distributions module. 
Its default behavior is to draw a histogram, using the same code as :func:`histplot` behind the scenes:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To draw a kernel density plot instead, using the same code as :func:`kdeplot`, select it using the ``kind`` parameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", multiple=\"stack\", kind=\"kde\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You'll notice that the figure-level plots look mostly like their axes-level counterparts, but there are a few differences. Notably, the legend is placed outside the plot. They also have a slightly different shape (more on that shortly).\n", "\n", "The most useful feature offered by the figure-level functions is that they can easily create figures with multiple subplots. For example, instead of stacking the three distributions for each species of penguins in the same axes, we can \"facet\" them by plotting each distribution across the columns of the figure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=penguins, x=\"flipper_length_mm\", hue=\"species\", col=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The figure-level functions wrap their axes-level counterparts and pass the kind-specific keyword arguments (such as the bin size for a histogram) down to the underlying function. That means they are no less flexible, but there is a downside: the kind-specific parameters don't appear in the function signature or docstrings. 
Some of their features might be less discoverable, and you may need to look at two different pages of the documentation before understanding how to achieve a specific goal." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Axes-level functions make self-contained plots\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "The axes-level functions are written to act like drop-in replacements for matplotlib functions. While they add axis labels and legends automatically, they don't modify anything beyond the axes that they are drawn into. That means they can be composed into arbitrarily-complex matplotlib figures with predictable results.\n", "\n", "The axes-level functions call :func:`matplotlib.pyplot.gca` internally, which hooks into the matplotlib state-machine interface so that they draw their plots on the \"currently-active\" axes. But they additionally accept an ``ax=`` argument, which integrates with the object-oriented interface and lets you specify exactly where each plot should go:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f, axs = plt.subplots(1, 2, figsize=(8, 4), gridspec_kw=dict(width_ratios=[4, 3]))\n", "sns.scatterplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\", ax=axs[0])\n", "sns.histplot(data=penguins, x=\"species\", hue=\"species\", shrink=.8, alpha=.8, legend=False, ax=axs[1])\n", "f.tight_layout()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Figure-level functions own their figure\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "In contrast, figure-level functions cannot (easily) be composed with other plots. By design, they \"own\" their own figure, including its initialization, so there's no notion of using a figure-level function to draw a plot onto an existing axes. 
This constraint allows the figure-level functions to implement features such as putting the legend outside of the plot.\n", "\n", "Nevertheless, it is possible to go beyond what the figure-level functions offer by accessing the matplotlib axes on the object that they return and adding other elements to the plot that way:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "g = sns.relplot(data=tips, x=\"total_bill\", y=\"tip\")\n", "g.ax.axline(xy1=(10, 2), slope=.2, color=\"b\", dashes=(5, 2))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Customizing plots from a figure-level function\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "The figure-level functions return a :class:`FacetGrid` instance, which has a few methods for customizing attributes of the plot in a way that is \"smart\" about the subplot organization. For example, you can change the labels on the external axes using a single line of code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.relplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", col=\"sex\")\n", "g.set_axis_labels(\"Flipper length (mm)\", \"Bill length (mm)\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "While convenient, this does add a bit of extra complexity, as you need to remember that this method is not part of the matplotlib API and exists only when using a figure-level function." ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _figure_size_tutorial:\n", "\n", "Specifying figure sizes\n", "^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "To increase or decrease the size of a matplotlib plot, you set the width and height of the entire figure, either in the `global rcParams `_, while setting up the plot (e.g. with the ``figsize`` parameter of :func:`matplotlib.pyplot.subplots`), or by calling a method on the figure object (e.g. 
:meth:`matplotlib.figure.Figure.set_size_inches`). When using an axes-level function in seaborn, the same rules apply: the size of the plot is determined by the size of the figure it is part of and the axes layout in that figure.\n", "\n", "When using a figure-level function, there are several key differences. First, the functions themselves have parameters to control the figure size (although these are actually parameters of the underlying :class:`FacetGrid` that manages the figure). Second, these parameters, ``height`` and ``aspect``, parameterize the size slightly differently than the ``width``, ``height`` parameterization in matplotlib (using the seaborn parameters, ``width = height * aspect``). Most importantly, the parameters correspond to the size of each *subplot*, rather than the size of the overall figure.\n", "\n", "To illustrate the difference between these approaches, here is the default output of :func:`matplotlib.pyplot.subplots` with one subplot:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f, ax = plt.subplots()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A figure with multiple columns will have the same overall size, but the axes will be squeezed horizontally to fit in the space:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "f, ax = plt.subplots(1, 2, sharey=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In contrast, a plot created by a figure-level function will be square. To demonstrate that, let's set up an empty plot by using :class:`FacetGrid` directly. 
This happens behind the scenes in functions like :func:`relplot`, :func:`displot`, or :func:`catplot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(penguins)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When additional columns are added, the figure itself will become wider, so that its subplots have the same size and shape:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(penguins, col=\"sex\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "And you can adjust the size and shape of each subplot without accounting for the total number of rows and columns in the figure:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.FacetGrid(penguins, col=\"sex\", height=3.5, aspect=.75)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The upshot is that you can assign faceting variables without stopping to think about how you'll need to adjust the total figure size. A downside is that, when you do want to change the figure size, you'll need to remember that things work a bit differently than they do in matplotlib." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Relative merits of figure-level functions\n", "^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^\n", "\n", "Here is a summary of the pros and cons that we have discussed above:\n", "\n", ".. 
list-table::\n", " :header-rows: 1\n", "\n", " * - Advantages\n", " - Drawbacks\n", " * - Easy faceting by data variables\n", " - Many parameters not in function signature\n", " * - Legend outside of plot by default\n", " - Cannot be part of a larger matplotlib figure\n", " * - Easy figure-level customization\n", " - Different API from matplotlib\n", " * - Different figure size parameterization\n", " - Different figure size parameterization\n", "\n", "On balance, the figure-level functions add some additional complexity that can make things more confusing for beginners, but their distinct features give them additional power. The tutorial documentation mostly uses the figure-level functions, because they produce slightly cleaner plots, and we generally recommend their use for most applications. The one situation where they are not a good choice is when you need to make a complex, standalone figure that composes multiple different plot kinds. At this point, it's recommended to set up the figure using matplotlib directly and to fill in the individual components using axes-level functions." ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Combining multiple views on the data\n", "------------------------------------\n", "\n", "Two important plotting functions in seaborn don't fit cleanly into the classification scheme discussed above. These functions, :func:`jointplot` and :func:`pairplot`, employ multiple kinds of plots from different modules to represent multiple aspects of a dataset in a single figure. Both plots are figure-level functions and create figures with multiple subplots by default. 
But they use different objects to manage the figure: :class:`JointGrid` and :class:`PairGrid`, respectively.\n", "\n", ":func:`jointplot` plots the relationship or joint distribution of two variables while adding marginal axes that show the univariate distribution of each one separately:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ":func:`pairplot` is similar — it combines joint and marginal views — but rather than focusing on a single relationship, it visualizes every pairwise combination of variables simultaneously:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(data=penguins, hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Behind the scenes, these functions are using axes-level functions that you have already met (:func:`scatterplot` and :func:`kdeplot`), and they also have a ``kind`` parameter that lets you quickly swap in a different representation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\", kind=\"hist\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/introduction.ipynb ================================================ { "cells": [ { 
"cell_type": "raw", "metadata": {}, "source": [ ".. _introduction:\n", "\n", ".. currentmodule:: seaborn\n", "\n", "An introduction to seaborn\n", "==========================\n", "\n", "Seaborn is a library for making statistical graphics in Python. It builds on top of `matplotlib <https://matplotlib.org/>`_ and integrates closely with `pandas <https://pandas.pydata.org/>`_ data structures.\n", "\n", "Seaborn helps you explore and understand your data. Its plotting functions operate on dataframes and arrays containing whole datasets and internally perform the necessary semantic mapping and statistical aggregation to produce informative plots. Its dataset-oriented, declarative API lets you focus on what the different elements of your plots mean, rather than on the details of how to draw them.\n", "\n", "Here's an example of what seaborn can do:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Import seaborn\n", "import seaborn as sns\n", "\n", "# Apply the default theme\n", "sns.set_theme()\n", "\n", "# Load an example dataset\n", "tips = sns.load_dataset(\"tips\")\n", "\n", "# Create a visualization\n", "sns.relplot(\n", " data=tips,\n", " x=\"total_bill\", y=\"tip\", col=\"time\",\n", " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A few things have happened here. Let's go through them one by one:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, "outputs": [], "source": [ "# Import seaborn\n", "import seaborn as sns" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Seaborn is the only library we need to import for this simple example. By convention, it is imported with the shorthand ``sns``.\n", "\n", "Behind the scenes, seaborn uses matplotlib to draw its plots. 
For interactive work, it's recommended to use a Jupyter/IPython interface in `matplotlib mode <https://ipython.readthedocs.io/en/stable/interactive/plotting.html>`_, or else you'll have to call :func:`matplotlib.pyplot.show` when you want to see the plot." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, "outputs": [], "source": [ "# Apply the default theme\n", "sns.set_theme()" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This uses the matplotlib rcParam system and will affect how all matplotlib plots look, even if you don't make them with seaborn. Beyond the default theme, there are :doc:`several other options </tutorial/aesthetics>`, and you can independently control the style and scaling of the plot to quickly translate your work between presentation contexts (e.g., making a version of your figure that will have readable fonts when projected during a talk). If you like the matplotlib defaults or prefer a different theme, you can skip this step and still use the seaborn plotting functions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, "outputs": [], "source": [ "# Load an example dataset\n", "tips = sns.load_dataset(\"tips\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Most code in the docs will use the :func:`load_dataset` function to get quick access to an example dataset. There's nothing special about these datasets: they are just pandas dataframes, and we could have loaded them with :func:`pandas.read_csv` or built them by hand. Most of the examples in the documentation will specify data using pandas dataframes, but seaborn is very flexible about the :doc:`data structures </tutorial/data_structure>` that it accepts." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide-output" ] }, "outputs": [], "source": [ "# Create a visualization\n", "sns.relplot(\n", " data=tips,\n", " x=\"total_bill\", y=\"tip\", col=\"time\",\n", " hue=\"smoker\", style=\"smoker\", size=\"size\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "This plot shows the relationship between five variables in the tips dataset using a single call to the seaborn function :func:`relplot`. Notice how we provided only the names of the variables and their roles in the plot. Unlike when using matplotlib directly, it wasn't necessary to specify attributes of the plot elements in terms of the color values or marker codes. Behind the scenes, seaborn handled the translation from values in the dataframe to arguments that matplotlib understands. This declarative approach lets you stay focused on the questions that you want to answer, rather than on the details of how to control matplotlib.\n", "\n", ".. _intro_api_abstraction:\n", "\n", "A high-level API for statistical graphics\n", "-----------------------------------------\n", "\n", "There is no universally best way to visualize data. Different questions are best answered by different plots. Seaborn makes it easy to switch between different visual representations by using a consistent dataset-oriented API.\n", "\n", "The function :func:`relplot` is named that way because it is designed to visualize many different statistical *relationships*. While scatter plots are often effective, relationships where one variable represents a measure of time are better represented by a line. 
The :func:`relplot` function has a convenient ``kind`` parameter that lets you easily switch to this alternate representation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dots = sns.load_dataset(\"dots\")\n", "sns.relplot(\n", " data=dots, kind=\"line\",\n", " x=\"time\", y=\"firing_rate\", col=\"align\",\n", " hue=\"choice\", size=\"coherence\", style=\"choice\",\n", " facet_kws=dict(sharex=False),\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Notice how the ``size`` and ``style`` parameters are used in both the scatter and line plots, but they affect the two visualizations differently: changing the marker area and symbol in the scatter plot vs the line width and dashing in the line plot. We did not need to keep those details in mind, letting us focus on the overall structure of the plot and the information we want it to convey.\n", "\n", ".. _intro_stat_estimation:\n", "\n", "Statistical estimation\n", "~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Often, we are interested in the *average* value of one variable as a function of other variables. Many seaborn functions will automatically perform the statistical estimation that is necessary to answer these questions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fmri = sns.load_dataset(\"fmri\")\n", "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", col=\"region\",\n", " hue=\"event\", style=\"event\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When statistical values are estimated, seaborn will use bootstrapping to compute confidence intervals and draw error bars representing the uncertainty of the estimate.\n", "\n", "Statistical estimation in seaborn goes beyond descriptive statistics. 
For example, it is possible to enhance a scatterplot by including a linear regression model (and its uncertainty) using :func:`lmplot`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(data=tips, x=\"total_bill\", y=\"tip\", col=\"time\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _intro_distributions:\n", "\n", "\n", "Distributional representations\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Statistical analyses require knowledge about the distribution of variables in your dataset. The seaborn function :func:`displot` supports several approaches to visualizing distributions. These include classic techniques like histograms and computationally-intensive approaches like kernel density estimation:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=tips, x=\"total_bill\", col=\"time\", kde=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Seaborn also tries to promote techniques that are powerful but less familiar, such as calculating and plotting the empirical cumulative distribution function of the data:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.displot(data=tips, kind=\"ecdf\", x=\"total_bill\", col=\"time\", hue=\"smoker\", rug=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _intro_categorical:\n", "\n", "Plots for categorical data\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Several specialized plot types in seaborn are oriented towards visualizing categorical data. They can be accessed through :func:`catplot`. These plots offer different levels of granularity. 
At the finest level, you may wish to see every observation by drawing a \"swarm\" plot: a scatter plot that adjusts the positions of the points along the categorical axis so that they don't overlap:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, kind=\"swarm\", x=\"day\", y=\"total_bill\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Alternatively, you could use kernel density estimation to represent the underlying distribution that the points are sampled from:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, kind=\"violin\", x=\"day\", y=\"total_bill\", hue=\"smoker\", split=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or you could show only the mean value and its confidence interval within each nested category:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.catplot(data=tips, kind=\"bar\", x=\"day\", y=\"total_bill\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _intro_dataset_funcs:\n", "\n", "Multivariate views on complex datasets\n", "--------------------------------------\n", "\n", "Some seaborn functions combine multiple kinds of plots to quickly give informative summaries of a dataset. One, :func:`jointplot`, focuses on a single relationship. 
It plots the joint distribution between two variables along with each variable's marginal distribution:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "penguins = sns.load_dataset(\"penguins\")\n", "sns.jointplot(data=penguins, x=\"flipper_length_mm\", y=\"bill_length_mm\", hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The other, :func:`pairplot`, takes a broader view: it shows joint and marginal distributions for all pairwise relationships and for each variable, respectively:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(data=penguins, hue=\"species\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _intro_figure_classes:\n", "\n", "Lower-level tools for building figures\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "These tools work by combining :doc:`axes-level ` plotting functions with objects that manage the layout of the figure, linking the structure of a dataset to a :doc:`grid of axes `. Both elements are part of the public API, and you can use them directly to create complex figures with only a few more lines of code:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "g = sns.PairGrid(penguins, hue=\"species\", corner=True)\n", "g.map_lower(sns.kdeplot, hue=None, levels=5, color=\".2\")\n", "g.map_lower(sns.scatterplot, marker=\"+\")\n", "g.map_diag(sns.histplot, element=\"step\", linewidth=0, kde=True)\n", "g.add_legend(frameon=True)\n", "g.legend.set_bbox_to_anchor((.61, .6))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. 
_intro_defaults:\n", "\n", "Opinionated defaults and flexible customization\n", "-----------------------------------------------\n", "\n", "Seaborn creates complete graphics with a single function call: when possible, its functions will automatically add informative axis labels and legends that explain the semantic mappings in the plot.\n", "\n", "In many cases, seaborn will also choose default values for its parameters based on characteristics of the data. For example, the :doc:`color mappings ` that we have seen so far used distinct hues (blue, orange, and sometimes green) to represent different levels of the categorical variables assigned to ``hue``. When mapping a numeric variable, some functions will switch to a continuous gradient:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=penguins,\n", " x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"body_mass_g\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When you're ready to share or publish your work, you'll probably want to polish the figure beyond what the defaults achieve. Seaborn allows for several levels of customization. It defines multiple built-in :doc:`themes ` that apply to all figures, its functions have standardized parameters that can modify the semantic mappings for each plot, and additional keyword arguments are passed down to the underlying matplotlib artists, allowing even more control. 
Once you've created a plot, its properties can be modified through both the seaborn API and by dropping down to the matplotlib layer for fine-grained tweaking:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.set_theme(style=\"ticks\", font_scale=1.25)\n", "g = sns.relplot(\n", " data=penguins,\n", " x=\"bill_length_mm\", y=\"bill_depth_mm\", hue=\"body_mass_g\",\n", " palette=\"crest\", marker=\"x\", s=100,\n", ")\n", "g.set_axis_labels(\"Bill length (mm)\", \"Bill depth (mm)\", labelpad=10)\n", "g.legend.set_title(\"Body mass (g)\")\n", "g.figure.set_size_inches(6.5, 4.5)\n", "g.ax.margins(.15)\n", "g.despine(trim=True)" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _intro_matplotlib:\n", "\n", "Relationship to matplotlib\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Seaborn's integration with matplotlib allows you to use it across the many environments that matplotlib supports, including exploratory analysis in notebooks, real-time interaction in GUI applications, and archival output in a number of raster and vector formats.\n", "\n", "While you can be productive using only seaborn functions, full customization of your graphics will require some knowledge of matplotlib's concepts and API. One aspect of the learning curve for new users of seaborn will be knowing when dropping down to the matplotlib layer is necessary to achieve a particular customization. On the other hand, users coming from matplotlib will find that much of their knowledge transfers.\n", "\n", "Matplotlib has a comprehensive and powerful API; just about any attribute of the figure can be changed to your liking. A combination of seaborn's high-level interface and matplotlib's deep customizability will allow you both to quickly explore your data and to create graphics that can be tailored into a `publication quality `_ final product." ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. 
_intro_next_steps:\n", "\n", "Next steps\n", "~~~~~~~~~~\n", "\n", "You have a few options for where to go next. You might first want to learn how to :doc:`install seaborn </installing>`. Once that's done, you can browse the :doc:`example gallery </examples/index>` to get a broader sense for what kind of graphics seaborn can produce. Or you can read through the rest of the :doc:`user guide and tutorial </tutorial>` for a deeper discussion of the different tools and what they are designed to accomplish. If you have a specific plot in mind and want to know how to make it, you could check out the :doc:`API reference </api>`, which documents each function's parameters and shows many examples to illustrate usage." ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/objects_interface.ipynb ================================================ { "cells": [ { "cell_type": "raw", "id": "35110bb9-6889-4bd5-b9d6-5a0479131433", "metadata": {}, "source": [ ".. _objects_tutorial:\n", "\n", ".. currentmodule:: seaborn.objects\n", "\n", "The seaborn.objects interface\n", "=============================\n", "\n", "The `seaborn.objects` namespace was introduced in version 0.12 as a completely new interface for making seaborn plots. It offers a more consistent and flexible API, comprising a collection of composable classes for transforming and plotting data. In contrast to the existing `seaborn` functions, the new interface aims to support end-to-end plot specification and customization without dropping down to matplotlib (although it will remain possible to do so if necessary).\n", "\n", ".. 
note::\n", " The objects interface is currently experimental and incomplete. It is stable enough for serious use, but there certainly are some rough edges and missing features." ] }, { "cell_type": "code", "execution_count": null, "id": "706badfa-58be-4808-9016-bd0ca3ebaf12", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import seaborn as sns\n", "import matplotlib as mpl\n", "tips = sns.load_dataset(\"tips\")\n", "penguins = sns.load_dataset(\"penguins\").dropna()\n", "diamonds = sns.load_dataset(\"diamonds\")\n", "healthexp = sns.load_dataset(\"healthexp\").sort_values([\"Country\", \"Year\"]).query(\"Year <= 2020\")" ] }, { "cell_type": "raw", "id": "dd1ceae5-f930-41c2-8a18-f3cf94a161ad", "metadata": {}, "source": [ "Specifying a plot and mapping data\n", "----------------------------------\n", "\n", "The objects interface should be imported with the following convention:" ] }, { "cell_type": "code", "execution_count": null, "id": "1c113156-20ad-4612-a9f5-0071d7fd35dd", "metadata": {}, "outputs": [], "source": [ "import seaborn.objects as so" ] }, { "cell_type": "raw", "id": "6518484e-828b-4e7c-8529-ed6c9e61fa69", "metadata": {}, "source": [ "The `seaborn.objects` namespace will provide access to all of the relevant classes. The most important is :class:`Plot`. You specify plots by instantiating a :class:`Plot` object and calling its methods. Let's see a simple example:" ] }, { "cell_type": "code", "execution_count": null, "id": "2e7f8ad0-9831-464b-9825-60733f110f34", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", " .add(so.Dot())\n", ")" ] }, { "cell_type": "raw", "id": "52785052-6c80-4f35-87e4-b27df499bd5c", "metadata": {}, "source": [ "This code, which produces a scatter plot, should look reasonably familiar. Just as when using :func:`seaborn.scatterplot`, we passed a tidy dataframe (`penguins`) and assigned two of its columns to the `x` and `y` coordinates of the plot. 
But instead of starting with the type of chart and then adding some data assignments, here we started with the data assignments and then added a graphical element.\n", "\n", "Setting properties\n", "~~~~~~~~~~~~~~~~~~\n", "\n", "The :class:`Dot` class is an example of a :class:`Mark`: an object that graphically represents data values. Each mark will have a number of properties that can be set to change its appearance:" ] }, { "cell_type": "code", "execution_count": null, "id": "310bac42-cfe4-4c45-9ddf-27c2cb200a8a", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\")\n", " .add(so.Dot(color=\"g\", pointsize=4))\n", ")" ] }, { "cell_type": "raw", "id": "3f817822-dd96-4263-a42e-824f9ca4083a", "metadata": {}, "source": [ "Mapping properties\n", "~~~~~~~~~~~~~~~~~~\n", "\n", "As with seaborn's functions, it is also possible to *map* data values to various graphical properties:" ] }, { "cell_type": "code", "execution_count": null, "id": "6267e411-1f75-461e-a189-ead4452b2ec6", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(\n", " penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " color=\"species\", pointsize=\"body_mass_g\",\n", " )\n", " .add(so.Dot())\n", ")" ] }, { "cell_type": "raw", "id": "b6bfc0bf-cae1-44ed-9f52-e9f748c3877d", "metadata": {}, "source": [ "While this basic functionality is not novel, an important difference from the function API is that properties are mapped using the same parameter names that would set them directly (instead of having `hue` vs. `color`, etc.). 
What matters is *where* the property is defined: passing a value when you initialize :class:`Dot` will set it directly, whereas assigning a variable when you set up the :class:`Plot` will *map* the corresponding data.\n", "\n", "Beyond this difference, the objects interface also allows a much wider range of mark properties to be mapped:" ] }, { "cell_type": "code", "execution_count": null, "id": "b8637528-4e17-4a41-be1c-2cb4275a5586", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(\n", " penguins, x=\"bill_length_mm\", y=\"bill_depth_mm\",\n", " edgecolor=\"sex\", edgewidth=\"body_mass_g\",\n", " )\n", " .add(so.Dot(color=\".8\"))\n", ")" ] }, { "cell_type": "raw", "id": "220930c4-410c-4452-a89e-95045f325cc0", "metadata": {}, "source": [ "Defining groups\n", "~~~~~~~~~~~~~~~\n", "\n", "The :class:`Dot` mark represents each data point independently, so the assignment of a variable to a property only has the effect of changing each dot's appearance. For marks that group or connect observations, such as :class:`Line`, it also determines the number of distinct graphical elements:" ] }, { "cell_type": "code", "execution_count": null, "id": "95f892e1-8adc-43d3-8b30-84d8c848040a", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(healthexp, x=\"Year\", y=\"Life_Expectancy\", color=\"Country\")\n", " .add(so.Line())\n", ")" ] }, { "cell_type": "raw", "id": "6665552c-674b-405e-a3ee-237517649349", "metadata": {}, "source": [ "It is also possible to define a grouping without changing any visual properties, by using `group`:" ] }, { "cell_type": "code", "execution_count": null, "id": "f9287beb-7a66-4dcb-bccf-9c5cab2790f4", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(healthexp, x=\"Year\", y=\"Life_Expectancy\", group=\"Country\")\n", " .add(so.Line())\n", ")" ] }, { "cell_type": "raw", "id": "be097dfa-e33c-41f5-8b5a-09013cb33e6e", "metadata": {}, "source": [ "Transforming data before plotting\n", "---------------------------------\n", 
"\n", "Statistical transformation\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "As with many seaborn functions, the objects interface supports statistical transformations. These are performed by :class:`Stat` objects, such as :class:`Agg`:" ] }, { "cell_type": "code", "execution_count": null, "id": "0964d2af-ce53-48b5-b79a-3277b05584dd", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\")\n", " .add(so.Bar(), so.Agg())\n", ")" ] }, { "cell_type": "raw", "id": "5ac229b2-3692-4d35-8ba3-e35262f198ce", "metadata": {}, "source": [ "In the function interface, statistical transformations are possible with some visual representations (e.g. :func:`seaborn.barplot`) but not others (e.g. :func:`seaborn.scatterplot`). The objects interface more cleanly separates representation and transformation, allowing you to compose :class:`Mark` and :class:`Stat` objects:" ] }, { "cell_type": "code", "execution_count": null, "id": "5c2f917d-1cb7-4d33-b8c4-2126a4f91ccc", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\")\n", " .add(so.Dot(pointsize=10), so.Agg())\n", ")" ] }, { "cell_type": "raw", "id": "1b9d7688-22f5-4f4a-b58e-71d8ff550b48", "metadata": {}, "source": [ "When forming groups by mapping properties, the :class:`Stat` transformation is applied to each group separately:" ] }, { "cell_type": "code", "execution_count": null, "id": "734f9dac-4663-4e51-8070-716c0c0296c6", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\", color=\"sex\")\n", " .add(so.Dot(pointsize=10), so.Agg())\n", ")" ] }, { "cell_type": "raw", "id": "e60a8e83-c34c-4769-b34f-e0c23c80b870", "metadata": {}, "source": [ "Resolving overplotting\n", "~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Some seaborn functions also have mechanisms that automatically resolve overplotting, as when :func:`seaborn.barplot` \"dodges\" bars once `hue` is assigned. 
The objects interface has less complex default behavior. Bars representing multiple groups will overlap by default:" ] }, { "cell_type": "code", "execution_count": null, "id": "96653815-7da3-4a77-877a-485b5e7578a4", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\", color=\"sex\")\n", " .add(so.Bar(), so.Agg())\n", ")" ] }, { "cell_type": "raw", "id": "06ee3b9f-0ae9-467f-8a40-e340e6f3ce7d", "metadata": {}, "source": [ "Nevertheless, it is possible to compose the :class:`Bar` mark with the :class:`Agg` stat and a second transformation, implemented by :class:`Dodge`:" ] }, { "cell_type": "code", "execution_count": null, "id": "e29792ae-c238-4538-952a-5af81adcefe0", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\", color=\"sex\")\n", " .add(so.Bar(), so.Agg(), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "a27dcb37-be58-427b-a722-9039b91b6503", "metadata": {}, "source": [ "The :class:`Dodge` class is an example of a :class:`Move` transformation, which is like a :class:`Stat` but only adjusts `x` and `y` coordinates. 
The :class:`Move` classes can be applied with any mark, and it's not necessary to use a :class:`Stat` first:" ] }, { "cell_type": "code", "execution_count": null, "id": "c4509ea7-36fe-4ffb-b784-e945d13fb93c", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\", color=\"sex\")\n", " .add(so.Dot(), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "a62e44ae-d6e7-4ab5-af2e-7b49a2031b1d", "metadata": {}, "source": [ "It's also possible to apply multiple :class:`Move` operations in sequence:" ] }, { "cell_type": "code", "execution_count": null, "id": "07536818-9ddd-46d1-b10c-b034fa257335", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\", y=\"body_mass_g\", color=\"sex\")\n", " .add(so.Dot(), so.Dodge(), so.Jitter(.3))\n", ")" ] }, { "cell_type": "raw", "id": "fd8ed5cc-6ba4-4d03-8414-57a782971d4c", "metadata": {}, "source": [ "Creating variables through transformation\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The :class:`Agg` stat requires both `x` and `y` to already be defined, but variables can also be *created* through statistical transformation. 
For example, the :class:`Hist` stat requires only one of `x` *or* `y` to be defined, and it will create the other by counting observations:" ] }, { "cell_type": "code", "execution_count": null, "id": "4b1f2c61-d294-4a85-a383-384d92523c36", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"species\")\n", " .add(so.Bar(), so.Hist())\n", ")" ] }, { "cell_type": "raw", "id": "9b33ea0c-f11d-48d7-be7c-13e9993906d8", "metadata": {}, "source": [ "The :class:`Hist` stat will also create new `x` values (by binning) when given numeric data:" ] }, { "cell_type": "code", "execution_count": null, "id": "25123abd-75d4-4550-ac86-5281fdabc023", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"flipper_length_mm\")\n", " .add(so.Bars(), so.Hist())\n", ")" ] }, { "cell_type": "raw", "id": "0dd84c56-eeb3-4904-b957-1677eaebd33c", "metadata": {}, "source": [ "Notice how we used :class:`Bars`, rather than :class:`Bar` for the plot with the continuous `x` axis. These two marks are related, but :class:`Bars` has different defaults and works better for continuous histograms. It also produces a different, more efficient matplotlib artist. You will find the pattern of singular/plural marks elsewhere. The plural version is typically optimized for cases with larger numbers of marks.\n", "\n", "Some transforms accept both `x` and `y`, but add *interval* data for each coordinate. 
This is particularly relevant for plotting error bars after aggregating:" ] }, { "cell_type": "code", "execution_count": null, "id": "6bc29e9d-d660-4638-80fd-8d77e15d9109", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"body_mass_g\", y=\"species\", color=\"sex\")\n", " .add(so.Range(), so.Est(errorbar=\"sd\"), so.Dodge())\n", " .add(so.Dot(), so.Agg(), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "3aecc891-1abb-45b2-bf15-c6944820b242", "metadata": {}, "source": [ "Orienting marks and transforms\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "When aggregating, dodging, and drawing a bar, the `x` and `y` variables are treated differently. Each operation has the concept of an *orientation*. The :class:`Plot` tries to determine the orientation automatically based on the data types of the variables. For instance, if we flip the assignment of `species` and `body_mass_g`, we'll get the same plot, but oriented horizontally:" ] }, { "cell_type": "code", "execution_count": null, "id": "1dd7ebeb-893e-4d27-aeaf-a8ff0cd2cc15", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"body_mass_g\", y=\"species\", color=\"sex\")\n", " .add(so.Bar(), so.Agg(), so.Dodge())\n", ")" ] }, { "cell_type": "raw", "id": "382603cb-9ae9-46ed-bceb-b48456781092", "metadata": {}, "source": [ "Sometimes, the correct orientation is ambiguous, as when both the `x` and `y` variables are numeric. 
In these cases, you can be explicit by passing the `orient` parameter to :meth:`Plot.add`:" ] }, { "cell_type": "code", "execution_count": null, "id": "75277dda-47c4-443c-9454-b8d97fc399e2", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"size\", color=\"time\")\n", " .add(so.Bar(), so.Agg(), so.Dodge(), orient=\"y\")\n", ")" ] }, { "cell_type": "raw", "id": "dc845c14-03e5-495d-9dc8-3a90f7879346", "metadata": {}, "source": [ "Building and displaying the plot\n", "--------------------------------\n", "\n", "Most examples thus far have produced a single subplot with just one kind of mark on it. But :class:`Plot` does not limit you to this.\n", "\n", "Adding multiple layers\n", "~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "More complex single-subplot graphics can be created by calling :meth:`Plot.add` repeatedly. Each time it is called, it defines a *layer* in the plot. For example, we may want to add a scatterplot (now using :class:`Dots`) and then a regression fit:" ] }, { "cell_type": "code", "execution_count": null, "id": "922b6d3d-7a81-4921-97f2-953a1fbc69ec", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", " .add(so.Dots())\n", " .add(so.Line(), so.PolyFit())\n", ")" ] }, { "cell_type": "raw", "id": "f0309733-a86a-4952-bc3b-533d639f0b52", "metadata": {}, "source": [ "Variable mappings that are defined in the :class:`Plot` constructor will be used for all layers:" ] }, { "cell_type": "code", "execution_count": null, "id": "604d16b9-383b-4b88-9ed7-fdefed55039a", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"time\")\n", " .add(so.Dots())\n", " .add(so.Line(), so.PolyFit())\n", ")" ] }, { "cell_type": "raw", "id": "eb56fb8d-aaa3-4b6e-b311-0354562174b5", "metadata": {}, "source": [ "Layer-specific mappings\n", "~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "You can also define a mapping such that it is used only in a specific layer. 
This is accomplished by defining the mapping within the call to :meth:`Plot.add` for the relevant layer:" ] }, { "cell_type": "code", "execution_count": null, "id": "f69a3a38-97e8-40fb-b7d4-95a751ebdcfb", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"tip\")\n", " .add(so.Dots(), color=\"time\")\n", " .add(so.Line(color=\".2\"), so.PolyFit())\n", ")" ] }, { "cell_type": "raw", "id": "b3f94f01-23d4-4f7a-98f8-de93dafc230a", "metadata": {}, "source": [ "Alternatively, define the mapping for the entire plot, but *remove* it from a specific layer by setting the variable to `None`:" ] }, { "cell_type": "code", "execution_count": null, "id": "45706bec-3453-4a7e-9ac7-c743baff4da6", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(tips, x=\"total_bill\", y=\"tip\", color=\"time\")\n", " .add(so.Dots())\n", " .add(so.Line(color=\".2\"), so.PolyFit(), color=None)\n", ")" ] }, { "cell_type": "raw", "id": "295013b3-7d91-4a59-b63b-fe50e642954c", "metadata": {}, "source": [ "To recap, there are three ways to specify the value of a mark property: (1) by mapping a variable in all layers, (2) by mapping a variable in a specific layer, and (3) by setting the property directly:" ] }, { "cell_type": "code", "execution_count": null, "id": "2341eafd-4d6f-4530-835a-a409d2057d74", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "from io import StringIO\n", "from IPython.display import SVG\n", "C = sns.color_palette(\"deep\")\n", "f = mpl.figure.Figure(figsize=(7, 3))\n", "ax = f.subplots()\n", "fontsize = 18\n", "ax.add_artist(mpl.patches.Rectangle((.13, .53), .45, .09, color=C[0], alpha=.3))\n", "ax.add_artist(mpl.patches.Rectangle((.22, .43), .235, .09, color=C[1], alpha=.3))\n", "ax.add_artist(mpl.patches.Rectangle((.49, .43), .26, .09, color=C[2], alpha=.3))\n", "ax.text(.05, .55, \"Plot(data, 'x', 'y', color='var1')\", size=fontsize, color=\".2\")\n", "ax.text(.05, .45, \".add(Dot(pointsize=10), marker='var2')\", 
size=fontsize, color=\".2\")\n", "annots = [\n", " (\"Mapped\\nin all layers\", (.35, .65), (0, 45)),\n", " (\"Set directly\", (.35, .4), (0, -45)),\n", " (\"Mapped\\nin this layer\", (.63, .4), (0, -45)),\n", "]\n", "for i, (text, xy, xytext) in enumerate(annots):\n", " ax.annotate(\n", " text, xy, xytext,\n", " textcoords=\"offset points\", fontsize=14, ha=\"center\", va=\"center\",\n", " arrowprops=dict(arrowstyle=\"->\", color=C[i]), color=C[i],\n", " )\n", "ax.set_axis_off()\n", "f.subplots_adjust(0, 0, 1, 1)\n", "f.savefig(s:=StringIO(), format=\"svg\")\n", "SVG(s.getvalue())" ] }, { "cell_type": "raw", "id": "cf2d8e39-d332-41f4-b327-2ac352878e58", "metadata": {}, "source": [ "Faceting and pairing subplots\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "As with seaborn's figure-level functions (:func:`seaborn.displot`, :func:`seaborn.catplot`, etc.), the :class:`Plot` interface can also produce figures with multiple \"facets\", or subplots containing subsets of data. This is accomplished with the :meth:`Plot.facet` method:" ] }, { "cell_type": "code", "execution_count": null, "id": "af737dfd-1cb2-418d-9f52-1deb93154a92", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"flipper_length_mm\")\n", " .facet(\"species\")\n", " .add(so.Bars(), so.Hist())\n", ")" ] }, { "cell_type": "raw", "id": "81c2a445-5ae1-4272-8a6c-8bfe1f3b907f", "metadata": {}, "source": [ "Call :meth:`Plot.facet` with the variables that should be used to define the columns and/or rows of the plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "b7b3495f-9a38-4976-b718-ce3672b8c186", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"flipper_length_mm\")\n", " .facet(col=\"species\", row=\"sex\")\n", " .add(so.Bars(), so.Hist())\n", ")" ] }, { "cell_type": "raw", "id": "8b7fe085-acd2-46d2-81f6-a806dec338d3", "metadata": {}, "source": [ "You can facet using a variable with a larger number of levels by \"wrapping\" across the other 
dimension:" ] }, { "cell_type": "code", "execution_count": null, "id": "d62d2310-ae33-4b42-bdea-7b7456afd640", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(healthexp, x=\"Year\", y=\"Life_Expectancy\")\n", " .facet(col=\"Country\", wrap=3)\n", " .add(so.Line())\n", ")" ] }, { "cell_type": "markdown", "id": "86ecbeee-3ac2-41eb-b79e-9d6ed026061d", "metadata": {}, "source": [ "All layers will be faceted unless you explicitly exclude them, which can be useful for providing additional context on each subplot:" ] }, { "cell_type": "code", "execution_count": null, "id": "c38be724-8564-4fa0-861c-1d96ffbbda20", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(healthexp, x=\"Year\", y=\"Life_Expectancy\")\n", " .facet(\"Country\", wrap=3)\n", " .add(so.Line(alpha=.3), group=\"Country\", col=None)\n", " .add(so.Line(linewidth=3))\n", ")" ] }, { "cell_type": "raw", "id": "f97dad75-65e6-47fd-9fc4-08a8f2cb49ee", "metadata": {}, "source": [ "An alternate way to produce subplots is :meth:`Plot.pair`. 
Like :class:`seaborn.PairGrid`, this draws all of the data on each subplot, using different variables for the x and/or y coordinates:" ] }, { "cell_type": "code", "execution_count": null, "id": "d6350e99-2c70-4a96-87eb-74756a0fa335", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, y=\"body_mass_g\", color=\"species\")\n", " .pair(x=[\"bill_length_mm\", \"bill_depth_mm\"])\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "raw", "id": "4deea650-b4b9-46ea-876c-2e5a3a258649", "metadata": {}, "source": [ "You can combine faceting and pairing so long as the operations add subplots on opposite dimensions:" ] }, { "cell_type": "code", "execution_count": null, "id": "9de7948c-4c43-4116-956c-cbcb84d8652c", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(penguins, y=\"body_mass_g\", color=\"species\")\n", " .pair(x=[\"bill_length_mm\", \"bill_depth_mm\"])\n", " .facet(row=\"sex\")\n", " .add(so.Dots())\n", ")" ] }, { "cell_type": "raw", "id": "0a0febe3-9daf-4271-aef9-9637d59aaf10", "metadata": {}, "source": [ "Integrating with matplotlib\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "There may be cases where you want multiple subplots to appear in a figure with a more complex structure than what :meth:`Plot.facet` or :meth:`Plot.pair` can provide. The current solution is to delegate figure setup to matplotlib and to supply the matplotlib object that :class:`Plot` should use with the :meth:`Plot.on` method. 
This object can be either a :class:`matplotlib.axes.Axes`, :class:`matplotlib.figure.Figure`, or :class:`matplotlib.figure.SubFigure`; the latter is most useful for constructing bespoke subplot layouts:" ] }, { "cell_type": "code", "execution_count": null, "id": "b046466d-f6c2-43fa-9ae9-f40a292a82b7", "metadata": {}, "outputs": [], "source": [ "f = mpl.figure.Figure(figsize=(8, 4))\n", "sf1, sf2 = f.subfigures(1, 2)\n", "(\n", " so.Plot(penguins, x=\"body_mass_g\", y=\"flipper_length_mm\")\n", " .add(so.Dots())\n", " .on(sf1)\n", " .plot()\n", ")\n", "(\n", " so.Plot(penguins, x=\"body_mass_g\")\n", " .facet(row=\"sex\")\n", " .add(so.Bars(), so.Hist())\n", " .on(sf2)\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "7074f599-8b9f-4b77-9e15-55349592c747", "metadata": {}, "source": [ "Building and displaying the plot\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "An important thing to know is that :class:`Plot` methods clone the object they are called on and return that clone instead of updating the object in place. 
This means that you can define a common plot spec and then produce several variations on it.\n", "\n", "So, take this basic specification:" ] }, { "cell_type": "code", "execution_count": null, "id": "b79b2148-b867-4e96-9b84-b3fc44ad0c82", "metadata": {}, "outputs": [], "source": [ "p = so.Plot(healthexp, \"Year\", \"Spending_USD\", color=\"Country\")" ] }, { "cell_type": "raw", "id": "135f89e5-c41e-4c6c-9865-5413787bdc62", "metadata": {}, "source": [ "We could use it to draw a line plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "10722a20-dc8c-4421-a433-8ff21fed9495", "metadata": {}, "outputs": [], "source": [ "p.add(so.Line())" ] }, { "cell_type": "raw", "id": "f9db1184-f352-41b8-a45a-02ff6eb85071", "metadata": {}, "source": [ "Or perhaps a stacked area plot:" ] }, { "cell_type": "code", "execution_count": null, "id": "ea2ad629-c718-44a9-92af-144728094cd5", "metadata": {}, "outputs": [], "source": [ "p.add(so.Area(), so.Stack())" ] }, { "cell_type": "raw", "id": "17fb2676-6199-4a2c-9f10-3d5aebb7a285", "metadata": {}, "source": [ "The :class:`Plot` methods are fully declarative. Calling them updates the plot spec, but it doesn't actually do any plotting. One consequence of this is that methods can be called in any order, and many of them can be called multiple times.\n", "\n", "When does the plot actually get rendered? :class:`Plot` is optimized for use in notebook environments. The rendering is automatically triggered when the :class:`Plot` gets displayed in the Jupyter REPL. That's why we didn't see anything in the example above, where we defined a :class:`Plot` but assigned it to `p` rather than letting it return out to the REPL.\n", "\n", "To see a plot in a notebook, either return it from the final line of a cell or call Jupyter's built-in `display` function on the object. 
The notebook integration bypasses :mod:`matplotlib.pyplot` entirely, but you can use its figure-display machinery in other contexts by calling :meth:`Plot.show`.\n", "\n", "You can also save the plot to a file (or buffer) by calling :meth:`Plot.save`." ] }, { "cell_type": "raw", "id": "abfa0384-af88-4409-a119-912601a14f13", "metadata": {}, "source": [ "Customizing the appearance\n", "--------------------------\n", "\n", "The new interface aims to support a deep amount of customization through :class:`Plot`, reducing the need to switch gears and use matplotlib functionality directly. (But please be patient; not all of the features needed to achieve this goal have been implemented!)\n", "\n", "Parameterizing scales\n", "~~~~~~~~~~~~~~~~~~~~~\n", "\n", "All of the data-dependent properties are controlled by the concept of a :class:`Scale` and the :meth:`Plot.scale` method. This method accepts several different types of arguments. One possibility, which is closest to the use of scales in matplotlib, is to pass the name of a function that transforms the coordinates:" ] }, { "cell_type": "code", "execution_count": null, "id": "5acfe6d2-144a-462d-965b-2900fb619eac", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, x=\"carat\", y=\"price\")\n", " .add(so.Dots())\n", " .scale(y=\"log\")\n", ")" ] }, { "cell_type": "raw", "id": "ccff884b-53cb-4c15-aab2-f5d4e5551d72", "metadata": {}, "source": [ ":meth:`Plot.scale` can also control the mappings for semantic properties like `color`. 
You can directly pass it any argument that you would pass to the `palette` parameter in seaborn's function interface:" ] }, { "cell_type": "code", "execution_count": null, "id": "4f243a31-d7da-43d2-8dc4-aad1b584ff48", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, x=\"carat\", y=\"price\", color=\"clarity\")\n", " .add(so.Dots())\n", " .scale(color=\"flare\")\n", ")" ] }, { "cell_type": "raw", "id": "4fdf291e-a008-4a8e-8ced-a24f78d9b49f", "metadata": {}, "source": [ "Another option is to provide a tuple of `(min, max)` values, controlling the range that the scale should map into. This works both for numeric properties and for colors:" ] }, { "cell_type": "code", "execution_count": null, "id": "4cdc12ee-83f9-4472-b198-85bfe5cf0e4f", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, x=\"carat\", y=\"price\", color=\"clarity\", pointsize=\"carat\")\n", " .add(so.Dots())\n", " .scale(color=(\"#88c\", \"#555\"), pointsize=(2, 10))\n", ")" ] }, { "cell_type": "raw", "id": "e326bf46-a296-4997-8e91-6531a7eef304", "metadata": {}, "source": [ "For additional control, you can pass a :class:`Scale` object. There are several different types of :class:`Scale`, each with appropriate parameters. 
For example, :class:`Continuous` lets you define the input domain (`norm`), the output range (`values`), and the function that maps between them (`trans`), while :class:`Nominal` allows you to specify an ordering:" ] }, { "cell_type": "code", "execution_count": null, "id": "53682db4-2ba4-4dfd-80c2-1fef466cfab2", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, x=\"carat\", y=\"price\", color=\"carat\", marker=\"cut\")\n", " .add(so.Dots())\n", " .scale(\n", " color=so.Continuous(\"crest\", norm=(0, 3), trans=\"sqrt\"),\n", " marker=so.Nominal([\"o\", \"+\", \"x\"], order=[\"Ideal\", \"Premium\", \"Good\"]),\n", " )\n", ")" ] }, { "cell_type": "raw", "id": "7bf112fe-136d-4e63-a397-1e7d2ff4f543", "metadata": {}, "source": [ "Customizing legends and ticks\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The :class:`Scale` objects are also how you specify which values should appear as tick labels / in the legend, along with how they appear. For example, the :meth:`Continuous.tick` method lets you control the density or locations of the ticks, and the :meth:`Continuous.label` method lets you modify the format:" ] }, { "cell_type": "code", "execution_count": null, "id": "4f8e821f-bd19-4af1-bb66-488593b3c968", "metadata": {}, "outputs": [], "source": [ "(\n", " so.Plot(diamonds, x=\"carat\", y=\"price\", color=\"carat\")\n", " .add(so.Dots())\n", " .scale(\n", " x=so.Continuous().tick(every=0.5),\n", " y=so.Continuous().label(like=\"${x:.0f}\"),\n", " color=so.Continuous().tick(at=[1, 2, 3, 4]),\n", " )\n", ")" ] }, { "cell_type": "raw", "id": "4f6646c9-084b-49ae-ad6f-39c0bd12fc4e", "metadata": {}, "source": [ "Customizing limits, labels, and titles\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", ":class:`Plot` has a number of methods for simple customization, including :meth:`Plot.label`, :meth:`Plot.limit`, and :meth:`Plot.share`:" ] }, { "cell_type": "code", "execution_count": null, "id": "e9586669-35ea-4784-9594-ea375a06aec0", "metadata": {}, 
"outputs": [], "source": [ "(\n", " so.Plot(penguins, x=\"body_mass_g\", y=\"species\", color=\"island\")\n", " .facet(col=\"sex\")\n", " .add(so.Dot(), so.Jitter(.5))\n", " .share(x=False)\n", " .limit(y=(2.5, -.5))\n", " .label(\n", " x=\"Body mass (g)\", y=\"\",\n", " color=str.capitalize,\n", " title=\"{} penguins\".format,\n", " )\n", ")" ] }, { "cell_type": "raw", "id": "3b38607a-9b41-49c0-8031-e05bc87701c8", "metadata": {}, "source": [ "Theme customization\n", "~~~~~~~~~~~~~~~~~~~\n", "\n", "Finally, :class:`Plot` supports data-independent theming through the :meth:`Plot.theme` method. Currently, this method accepts a dictionary of matplotlib rc parameters. You can set them directly and/or pass a package of parameters from seaborn's theming functions:" ] }, { "cell_type": "code", "execution_count": null, "id": "2df40831-fd41-4b76-90ff-042aecd694d4", "metadata": {}, "outputs": [], "source": [ "from seaborn import axes_style\n", "theme_dict = {**axes_style(\"whitegrid\"), \"grid.linestyle\": \":\"}\n", "so.Plot().theme(theme_dict)" ] }, { "cell_type": "raw", "id": "475d5157-5e88-473e-991f-528219ed3744", "metadata": {}, "source": [ "To change the theme for all :class:`Plot` instances, update the settings in :attr:`Plot.config`:" ] }, { "cell_type": "code", "execution_count": null, "id": "41ac347c-766f-495c-8a7f-43fee8cad29a", "metadata": {}, "outputs": [], "source": [ "so.Plot.config.theme.update(theme_dict)" ] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_tutorial/properties.ipynb ================================================ { "cells": [ { "cell_type": "raw", "id": 
"6cb222bb-4781-48b6-9675-c0ba195b5efb", "metadata": {}, "source": [ ".. _properties_tutorial:\n", "\n", "Properties of Mark objects\n", "===========================" ] }, { "cell_type": "code", "execution_count": null, "id": "ae9d52dc-55ad-4804-a533-f2b724d0b85b", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib as mpl\n", "import seaborn.objects as so\n", "from seaborn import axes_style, color_palette" ] }, { "cell_type": "raw", "id": "dd828c60-3895-46e4-a2f4-782a6e6cd9a6", "metadata": {}, "source": [ "Coordinate properties\n", "---------------------" ] }, { "cell_type": "raw", "id": "fa97cc40-f02f-477b-90ec-a764b7253b68", "metadata": {}, "source": [ ".. _coordinate_property:\n", "\n", "x, y, xmin, xmax, ymin, ymax\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Coordinate properties determine where a mark is drawn on a plot. Canonically, the `x` coordinate is the horizontal position and the `y` coordinate is the vertical position. Some marks accept a span (i.e., `min`, `max`) parameterization for one or both variables. Others may accept `x` and `y` but also use a `baseline` parameter to show a span. The layer's `orient` parameter determines how this works.\n", "\n", "If a variable does not contain numeric data, its scale will apply a conversion so that data can be drawn on a screen. 
For instance, :class:`Nominal` scales assign an integer index to each distinct category, and :class:`Temporal` scales represent dates as the number of days from a reference \"epoch\":" ] }, { "cell_type": "code", "execution_count": null, "id": "7b418365-b99c-45d6-bf1e-e347e2b9012a", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "(\n", " so.Plot(y=[0, 0, 0])\n", " .pair(x=[\n", " [1, 2, 3],\n", " [\"A\", \"B\", \"C\"],\n", " np.array([\"2020-01-01\", \"2020-02-01\", \"2020-03-01\"], dtype=\"datetime64\"),\n", " ])\n", " .limit(\n", " x0=(0, 10),\n", " x1=(-.5, 2.5),\n", " x2=(pd.Timestamp(\"2020-01-01\"), pd.Timestamp(\"2020-03-01\"))\n", " )\n", " .scale(y=so.Continuous().tick(count=0), x2=so.Temporal().label(concise=True))\n", " .layout(size=(7, 1), engine=\"tight\")\n", " .label(x0=\"Continuous\", x1=\"Nominal\", x2=\"Temporal\")\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", \"right\", \"top\"]},\n", " })\n", ")" ] }, { "cell_type": "raw", "id": "0ae06665-2ce5-470d-b90a-02d990221fc5", "metadata": {}, "source": [ "A :class:`Continuous` scale can also apply a nonlinear transform between data values and spatial positions:" ] }, { "cell_type": "code", "execution_count": null, "id": "b731a3bb-a52e-4b12-afbb-b036753adcbe", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "(\n", " so.Plot(y=[0, 0, 0])\n", " .pair(x=[[1, 10, 100], [-100, 0, 100], [0, 10, 40]])\n", " .limit(\n", " )\n", " .add(so.Dot(marker=\"\"))\n", " .scale(\n", " y=so.Continuous().tick(count=0),\n", " x0=so.Continuous(trans=\"log\"),\n", " x1=so.Continuous(trans=\"symlog\").tick(at=[-100, -10, 0, 10, 100]),\n", " x2=so.Continuous(trans=\"sqrt\").tick(every=10),\n", " )\n", " .layout(size=(7, 1), engine=\"tight\")\n", " .label(x0=\"trans='log'\", x1=\"trans='symlog'\", x2=\"trans='sqrt'\")\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", 
\"right\", \"top\"]},\n", " \"axes.labelpad\": 8,\n", " })\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "e384941a-da38-4e12-997d-d750b19b1fa6", "metadata": { "tags": [ "hide-input", "hide" ] }, "outputs": [], "source": [ "# Hiding from the page but keeping around for now\n", "(\n", " so.Plot()\n", " .add(\n", " so.Dot(edgewidth=3, stroke=3),\n", " so.Dodge(by=[\"group\"]),\n", " x=[\"A\", \"A\", \"A\", \"A\", \"A\"],\n", " y=[1.75, 2.25, 2.75, 2.0, 2.5],\n", " color=[1, 2, 3, 1, 3],\n", " marker=[mpl.markers.MarkerStyle(x) for x in \"os^+o\"],\n", " pointsize=(9, 9, 9, 13, 10),\n", " fill=[True, False, True, True, False],\n", " group=[1, 2, 3, 4, 5], width=.5, legend=False,\n", " )\n", " .add(\n", " so.Bar(edgewidth=2.5, alpha=.2, width=.9),\n", " so.Dodge(gap=.05),\n", " x=[\"B\", \"B\", \"B\",], y=[2, 2.5, 1.75], color=[1, 2, 3],\n", " legend=False,\n", " )\n", " .add(\n", " so.Range({\"capstyle\": \"round\"}, linewidth=3),\n", " so.Dodge(by=[\"group\"]),\n", " x=[\"C\", \"C\", \"C\"], ymin=[1.5, 1.75, 1.25], ymax=[2.5, 2.75, 2.25],\n", " color=[1, 2, 2], linestyle=[\"-\", \"-\", \":\"],\n", " group=[1, 2, 3], width=.5, legend=False,\n", " )\n", " .layout(size=(4, 4), engine=None)\n", " .limit(x=(-.5, 2.5), y=(0, 3))\n", " .label(x=\"X Axis (nominal)\", y=\"Y Axis (continuous)\")\n", " .scale(\n", " color=\"dark:C0_r\", #None,\n", " fill=None, marker=None,\n", " pointsize=None, linestyle=None,\n", " y=so.Continuous().tick(every=1, minor=1)\n", " )\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " \"axes.spines.top\": False, \"axes.spines.right\": False,\n", " \"axes.labelsize\": 14,\n", " })\n", ")" ] }, { "cell_type": "raw", "id": "8279d74f-0cd0-4ba8-80ed-c6051541d956", "metadata": {}, "source": [ "Color properties\n", "----------------" ] }, { "cell_type": "raw", "id": "fca25527-6bbe-42d6-beea-a996a46d9761", "metadata": {}, "source": [ ".. 
_color_property:\n", "\n", "color, fillcolor, edgecolor\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "All marks can be given a `color`, and many distinguish between the color of the mark's \"edge\" and \"fill\". Often, simply using `color` will set both, while the more-specific properties allow further control:" ] }, { "cell_type": "code", "execution_count": null, "id": "ff7a1e64-7b02-45b8-b1e7-d7ec2bf1e7f7", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "no_spines = {\n", " f\"axes.spines.{side}\": False\n", " for side in [\"left\", \"right\", \"bottom\", \"top\"]\n", "}" ] }, { "cell_type": "code", "execution_count": null, "id": "1dda4c42-31f4-4316-baad-f30a465d3fd9", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "color_mark = so.Dot(marker=\"s\", pointsize=20, edgewidth=2.5, alpha=.7, edgealpha=1)\n", "color_plot = (\n", " so.Plot()\n", " .theme({\n", " **axes_style(\"white\"),\n", " **no_spines,\n", " \"axes.titlesize\": 15,\n", " \"figure.subplot.wspace\": .1,\n", " \"axes.xmargin\": .1,\n", " })\n", " .scale(\n", " x=so.Continuous().tick(count=0),\n", " y=so.Continuous().tick(count=0),\n", " color=None, edgecolor=None,\n", " )\n", " .layout(size=(9, .5), engine=None)\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "54fc98b4-dc4c-45e1-a2a7-840a724fc746", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "n = 6\n", "rgb = [f\"C{i}\" for i in range(n)]\n", "(\n", " color_plot\n", " .facet([\"color\"] * n + [\"edgecolor\"] * n + [\"fillcolor\"] * n)\n", " .add(\n", " color_mark,\n", " x=np.tile(np.arange(n), 3),\n", " y=np.zeros(n * 3),\n", " color=rgb + [\".8\"] * n + rgb,\n", " edgecolor=rgb + rgb + [\".3\"] * n,\n", " legend=False,\n", " )\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "0dc26a01-6290-44f4-9815-5cea531207e2", "metadata": {}, "source": [ "When the color property is mapped, the default palette depends on the type of scale. 
Nominal scales use discrete, unordered hues, while continuous scales (including temporal ones) use a sequential gradient:" ] }, { "cell_type": "code", "execution_count": null, "id": "6927a0d3-687b-4ca0-a425-0376b39f1b1f", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "n = 9\n", "rgb = color_palette(\"deep\", n) + color_palette(\"ch:\", n)\n", "(\n", " color_plot\n", " .facet([\"nominal\"] * n + [\"continuous\"] * n)\n", " .add(\n", " color_mark,\n", " x=list(range(n)) * 2,\n", " y=[0] * n * 2,\n", " color=rgb,\n", " legend=False,\n", " )\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "e79d0da7-a53e-468c-9952-726eeae810d1", "metadata": {}, "source": [ ".. note::\n", " The default continuous scale is subject to change in future releases to improve discriminability.\n", "\n", "Color scales are parameterized by the name of a palette, such as `'viridis'`, `'rocket'`, or `'deep'`. Some palette names can include parameters, including simple gradients (e.g. `'dark:blue'`) or the cubehelix system (e.g. `'ch:start=.2,rot=-.4'`). See the :doc:`color palette tutorial </tutorial/color_palettes>` for guidance on making an appropriate selection.\n", "\n", "Continuous scales can also be parameterized by a tuple of colors that the scale should interpolate between. When using a nominal scale, it is possible to provide either the name of the palette (which will be discretely-sampled, if necessary), a list of individual color values, or a dictionary directly mapping data values to colors.\n", "\n", "Individual colors may be specified `in a wide range of formats <https://matplotlib.org/stable/tutorials/colors/colors.html>`_. 
These include indexed references to the current color cycle (`'C0'`), single-letter shorthands (`'b'`), grayscale values (`'.4'`), RGB hex codes (`'#4c72b0'`), X11 color names (`'seagreen'`), and XKCD color survey names (`'purpleish'`):" ] }, { "cell_type": "code", "execution_count": null, "id": "ce7300dc-0ed2-4eb3-bd6f-2e42280f5e54", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "color_dict = {\n", " \"cycle\": [\"C0\", \"C1\", \"C2\"],\n", " \"short\": [\"r\", \"y\", \"b\"],\n", " \"gray\": [\".3\", \".7\", \".5\"],\n", " \"hex\": [\"#825f87\", \"#05696b\", \"#de7e5d\"],\n", " \"X11\": [\"seagreen\", \"sienna\", \"darkblue\"],\n", " \"XKCD\": [\"xkcd:gold\", \"xkcd:steel\", \"xkcd:plum\"],\n", "}\n", "groups = [k for k in color_dict for _ in range(3)]\n", "colors = [c for pal in color_dict.values() for c in pal]\n", "(\n", " so.Plot(\n", " x=[0] * len(colors),\n", " y=[f\"'{c}'\" for c in colors],\n", " color=colors,\n", " )\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **no_spines,\n", " \"axes.ymargin\": .2,\n", " \"axes.titlesize\": 14,\n", " \n", " })\n", " .facet(groups)\n", " .layout(size=(8, 1.15), engine=\"constrained\")\n", " .scale(x=so.Continuous().tick(count=0))\n", " .add(color_mark)\n", " .limit(x=(-.2, .5))\n", " # .label(title=\"{} \".format)\n", " .label(title=\"\")\n", " .scale(color=None)\n", " .share(y=False)\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "4ea6ac35-2a73-4dec-8b9b-bf15ba67f01b", "metadata": {}, "source": [ ".. _alpha_property:\n", "\n", "alpha, fillalpha, edgealpha\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The `alpha` property determines the mark's opacity. 
Lowering the alpha can be helpful for representing density in the case of overplotting:" ] }, { "cell_type": "code", "execution_count": null, "id": "e73839d2-27c4-42b8-8587-9f6e99c8a464", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "rng = np.random.default_rng(3)\n", "n_samp = 300\n", "x = 1 - rng.exponential(size=n_samp)\n", "y = rng.uniform(-1, 1, size=n_samp)\n", "keep = np.sqrt(x ** 2 + y ** 2) < 1\n", "x, y = x[keep], y[keep]\n", "n = keep.sum()\n", "alpha_vals = np.linspace(.1, .9, 9).round(1)\n", "xs = np.concatenate([x for _ in alpha_vals])\n", "ys = np.concatenate([y for _ in alpha_vals])\n", "alphas = np.repeat(alpha_vals, n)\n", "(\n", " so.Plot(x=xs, y=ys, alpha=alphas)\n", " .facet(alphas)\n", " .add(so.Dot(color=\".2\", pointsize=3))\n", " .scale(\n", " alpha=None,\n", " x=so.Continuous().tick(count=0),\n", " y=so.Continuous().tick(count=0)\n", " )\n", " .layout(size=(9, 1), engine=None)\n", " .theme({\n", " **axes_style(\"white\"),\n", " **no_spines,\n", " })\n", ")" ] }, { "cell_type": "raw", "id": "a551732e-e8f5-45f0-9345-7ef45248d9d7", "metadata": {}, "source": [ "Mapping the `alpha` property can also be useful even when marks do not overlap because it conveys a sense of importance and can be combined with a `color` scale to represent two variables. Moreover, colors with lower alpha appear less saturated, which can improve the appearance of larger filled marks (such as bars).\n", "\n", "As with `color`, some marks define separate `edgealpha` and `fillalpha` properties for additional control." ] }, { "cell_type": "raw", "id": "77d168e4-0539-409f-8542-750d3981e22b", "metadata": {}, "source": [ "Style properties\n", "----------------" ] }, { "cell_type": "raw", "id": "95e342fa-1086-4e63-81ae-dce1c628df9b", "metadata": {}, "source": [ ".. _fill_property:\n", "\n", "fill\n", "~~~~\n", "\n", "The `fill` property is relevant to marks with a distinction between the edge and interior and determines whether the interior is visible. 
It is a boolean state: `fill` can be set only to `True` or `False`:" ] }, { "cell_type": "code", "execution_count": null, "id": "5fb3b839-8bae-4392-b5f0-70dfc5a33c7a", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "nan = float(\"nan\")\n", "x_bar = [0, 1]\n", "y_bar = [2, 1]\n", "f_bar = [True, False]\n", "\n", "x_dot = [2.2, 2.5, 2.8, 3.2, 3.5, 3.8]\n", "y_dot = [1.2, 1.7, 1.4, 0.7, 1.2, 0.9]\n", "f_dot = [True, True, True, False, False, False]\n", "\n", "xx = np.linspace(0, .8, 100)\n", "yy = xx ** 2 * np.exp(-xx * 10)\n", "x_area = list(4.5 + xx) + list(5.5 + xx)\n", "y_area = list(yy / yy.max() * 2) + list(yy / yy.max())\n", "f_area = [True] * 100 + [False] * 100\n", "\n", "(\n", " so.Plot()\n", " .add(\n", " so.Bar(color=\".3\", edgecolor=\".2\", edgewidth=2.5),\n", " x=x_bar + [nan for _ in x_dot + x_area],\n", " y=y_bar + [nan for _ in y_dot + y_area],\n", " fill=f_bar + [nan for _ in f_dot + f_area]\n", " )\n", " .add(\n", " so.Dot(color=\".2\", pointsize=13, stroke=2.5),\n", " x=[nan for _ in x_bar] + x_dot + [nan for _ in x_area],\n", " y=[nan for _ in y_bar] + y_dot + [nan for _ in y_area],\n", " fill=[nan for _ in f_bar] + f_dot + [nan for _ in f_area],\n", " )\n", " .add(\n", " so.Area(color=\".2\", edgewidth=2.5),\n", " x=[nan for _ in x_bar + x_dot] + x_area,\n", " y=[nan for _ in y_bar + y_dot] + y_area,\n", " fill=[nan for _ in f_bar + f_dot] + f_area,\n", " )\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " \"axes.spines.left\": False,\n", " \"axes.spines.top\": False,\n", " \"axes.spines.right\": False,\n", " \"xtick.labelsize\": 14,\n", " })\n", " .layout(size=(9, 1.25), engine=None)\n", " .scale(\n", " fill=None,\n", " x=so.Continuous().tick(at=[0, 1, 2.5, 3.5, 4.8, 5.8]).label(\n", " like={\n", " 0: True, 1: False, 2.5: True, 3.5: False, 4.8: True, 5.8: False\n", " }.get,\n", " ),\n", " y=so.Continuous().tick(count=0),\n", " )\n", ")" ] }, { "cell_type": "raw", "id": "119741b0-9eca-45a1-983e-35effc49c7fa", "metadata": 
{}, "source": [ ".. _marker_property:\n", "\n", "marker\n", "~~~~~~\n", "\n", "The `marker` property is relevant for dot marks and some line marks. The API for specifying markers is very flexible, as detailed in the matplotlib API docs: :mod:`matplotlib.markers`." ] }, { "cell_type": "code", "execution_count": null, "id": "0ba9c5aa-3d9c-47c7-8aee-5851e1f3c4dd", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "marker_plot = (\n", " so.Plot()\n", " .scale(marker=None, y=so.Continuous().tick(count=0))\n", " .layout(size=(10, .5), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " \"axes.spines.left\": False,\n", " \"axes.spines.top\": False,\n", " \"axes.spines.right\": False,\n", " \"xtick.labelsize\":12,\n", " \"axes.xmargin\": .02,\n", " })\n", "\n", ")\n", "marker_mark = so.Dot(pointsize=15, color=\".2\", stroke=1.5)" ] }, { "cell_type": "raw", "id": "3c07a874-18a1-485a-8d65-70ea3f246340", "metadata": {}, "source": [ "Markers can be specified using a number of simple string codes:" ] }, { "cell_type": "code", "execution_count": null, "id": "6a764efd-df55-412b-8a01-8eba6f897893", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "marker_codes = [\n", " \"o\", \"^\", \"v\", \"<\", \">\",\"s\", \"D\", \"d\", \"p\", \"h\", \"H\", \"8\",\n", " \"X\", \"*\", \".\", \"P\", \"x\", \"+\", \"1\", \"2\", \"3\", \"4\", \"|\", \"_\",\n", "]\n", "x, y = [f\"'{m}'\" for m in marker_codes], [0] * len(marker_codes)\n", "marker_objs = [mpl.markers.MarkerStyle(m) for m in marker_codes]\n", "marker_plot.add(marker_mark, marker=marker_objs, x=x, y=y).plot()" ] }, { "cell_type": "raw", "id": "1c614f08-3aa4-450d-bfe2-3295c29155d5", "metadata": {}, "source": [ "They can also be programmatically generated using a `(num_sides, fill_style, angle)` tuple:" ] }, { "cell_type": "code", "execution_count": null, "id": "c9c1efe7-33e1-4add-9c4e-567d8dfbb821", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "marker_codes = [\n", " 
(4, 0, 0), (4, 0, 45), (8, 0, 0),\n", " (4, 1, 0), (4, 1, 45), (8, 1, 0),\n", " (4, 2, 0), (4, 2, 45), (8, 2, 0),\n", "]\n", "x, y = [f\"{m}\" for m in marker_codes], [0] * len(marker_codes)\n", "marker_objs = [mpl.markers.MarkerStyle(m) for m in marker_codes]\n", "marker_plot.add(marker_mark, marker=marker_objs, x=x, y=y).plot()" ] }, { "cell_type": "raw", "id": "dc518508-cb08-4508-a7f3-5762841da6fc", "metadata": {}, "source": [ "See the matplotlib docs for additional formats, including mathtext character codes (`'$...$'`) and arrays of vertices.\n", "\n", "A marker property is always mapped with a nominal scale; there is no inherent ordering to the different shapes. If no scale is provided, the plot will programmatically generate a suitably large set of unique markers:" ] }, { "cell_type": "code", "execution_count": null, "id": "3466dc10-07a5-470f-adac-c3c05326945d", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "from seaborn._core.properties import Marker\n", "n = 14\n", "marker_objs = Marker()._default_values(n)\n", "x, y = list(map(str, range(n))), [0] * n\n", "marker_plot.add(marker_mark, marker=marker_objs, x=x, y=y).plot()" ] }, { "cell_type": "raw", "id": "30916c65-6d4c-4294-a5e2-58af8b9392f3", "metadata": {}, "source": [ "While this ensures that the shapes are technically distinct, bear in mind that — in most cases — it will be difficult to tell the markers apart if more than a handful are used in a single plot.\n", "\n", ".. note::\n", " The default marker scale is subject to change in future releases to improve discriminability." ] }, { "cell_type": "raw", "id": "3b1d0630-808a-4099-8bd0-768718f86f72", "metadata": {}, "source": [ ".. _linestyle_property:\n", "\n", "linestyle, edgestyle\n", "~~~~~~~~~~~~~~~~~~~~\n", "\n", "The `linestyle` property is relevant to line marks, and the `edgestyle` property is relevant to a number of marks with \"edges\". 
Both properties determine the \"dashing\" of a line in terms of on-off segments.\n", "\n", "Dashes can be specified with a small number of shorthand codes (`'-'`, `'--'`, `'-.'`, and `':'`) or programmatically using `(on, off, ...)` tuples. In the tuple specification, the unit is equal to the linewidth:" ] }, { "cell_type": "code", "execution_count": null, "id": "33a729db-84e4-4619-bd1a-1f60c77f7073", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "xx = np.linspace(0, 1, 100)\n", "dashes = [\"-\", \"--\", \"-.\", \":\", (6, 2), (2, 1), (.5, .5), (4, 1, 2, 1)] \n", "dash_data = (\n", " pd.DataFrame({i: xx for i in range(len(dashes))})\n", " .stack()\n", " .reset_index(1)\n", " .set_axis([\"y\", \"x\"], axis=1)\n", " .reset_index(drop=True)\n", ")\n", "(\n", " so.Plot(dash_data, \"x\", \"y\", linestyle=\"y\")\n", " .add(so.Line(linewidth=1.7, color=\".2\"), legend=None)\n", " .scale(\n", " linestyle=dashes,\n", " x=so.Continuous().tick(count=0),\n", " y=so.Continuous().tick(every=1).label(like={\n", " i: f\"'$\\mathtt{{{pat}}}$'\" if isinstance(pat, str) else pat\n", " for i, pat in enumerate(dashes)\n", " }.get)\n", " )\n", " .label(x=\"\", y=\"\")\n", " .limit(x=(0, 1), y=(7.5, -0.5))\n", " .layout(size=(9, 2.5), engine=None)\n", " .theme({\n", " **axes_style(\"white\"),\n", " **no_spines,\n", " \"ytick.labelsize\": 12,\n", " })\n", ")" ] }, { "cell_type": "raw", "id": "41063f3b-a207-4f03-a606-78e2826be522", "metadata": {}, "source": [ "Size properties\n", "---------------" ] }, { "cell_type": "raw", "id": "7a909d91-9d60-4e95-a855-18b2779f19ce", "metadata": {}, "source": [ ".. _pointsize_property:\n", "\n", "pointsize\n", "~~~~~~~~~\n", "\n", "The `pointsize` property is relevant to dot marks and to line marks that can show markers at individual data points. 
The units correspond to the diameter of the mark in points.\n", "\n", "Note that, while the parameterization corresponds to diameter, scales will be applied with a square root transform so that data values are linearly proportional to area:" ] }, { "cell_type": "code", "execution_count": null, "id": "b55b106d-ba14-43ec-ab9b-5d7a04fb813c", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "x = np.arange(1, 21)\n", "y = [0 for _ in x]\n", "(\n", " so.Plot(x, y)\n", " .add(so.Dots(color=\".2\", stroke=1), pointsize=x)\n", " .layout(size=(9, .5), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", \"right\", \"top\"]},\n", " \"xtick.labelsize\": 12,\n", " \"axes.xmargin\": .025,\n", " })\n", " .scale(\n", " pointsize=None,\n", " x=so.Continuous().tick(every=1),\n", " y=so.Continuous().tick(count=0),\n", " )\n", ")" ] }, { "cell_type": "raw", "id": "66660d74-0252-4cb1-960a-c2c4823bb0e6", "metadata": {}, "source": [ ".. _linewidth_property:\n", "\n", "linewidth\n", "~~~~~~~~~\n", "\n", "The `linewidth` property is relevant to line marks and determines their thickness. 
The value should be non-negative and has point units:" ] }, { "cell_type": "code", "execution_count": null, "id": "a77c60d5-0d21-43a5-ab8c-f3f4abbc70ad", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "lw = np.arange(0.5, 5, .5)\n", "x = [i for i in [0, 1] for _ in lw]\n", "y = [*lw, *lw]\n", "(\n", " so.Plot(x=x, y=y, linewidth=y)\n", " .add(so.Line(color=\".2\"))\n", " .limit(y=(4.9, .1))\n", " .layout(size=(9, 1.4), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"bottom\", \"right\", \"top\"]},\n", " \"xtick.labelsize\": 12,\n", " \"axes.xmargin\": .015,\n", " \"ytick.labelsize\": 12,\n", " })\n", " .scale(\n", " linewidth=None,\n", " x=so.Continuous().tick(count=0),\n", " y=so.Continuous().tick(every=1, between=(.5, 4.5), minor=1),\n", " )\n", ")" ] }, { "cell_type": "raw", "id": "dcbdfcb9-d55e-467a-8514-bdb4cc2bec90", "metadata": {}, "source": [ ".. _edgewidth_property:\n", "\n", "edgewidth\n", "~~~~~~~~~\n", "\n", "The `edgewidth` property is akin to `linewidth` but applies to marks with an edge/fill rather than to lines. It also has a different default range when used in a scale. 
The units are the same:" ] }, { "cell_type": "code", "execution_count": null, "id": "7a1f1d5a-a2d5-4b8e-a172-73104f5ec715", "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "x = np.arange(0, 21) / 5\n", "y = [0 for _ in x]\n", "edge_plot = (\n", " so.Plot(x, y)\n", " .layout(size=(9, .5), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", \"right\", \"top\"]},\n", " \"xtick.labelsize\": 12,\n", " \"axes.xmargin\": .02,\n", " })\n", " .scale(\n", " x=so.Continuous().tick(every=1, minor=4),\n", " y=so.Continuous().tick(count=0),\n", " )\n", ")" ] }, { "cell_type": "code", "execution_count": null, "id": "ba70ed6c-d902-41b0-a043-d8f27bf65e9b", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "(\n", " edge_plot\n", " .add(so.Dot(color=\".75\", edgecolor=\".2\", marker=\"o\", pointsize=14), edgewidth=x)\n", " .scale(edgewidth=None)\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "98a25a16-67fa-4467-a425-6a78a17c63ab", "metadata": {}, "source": [ ".. _stroke_property:\n", "\n", "stroke\n", "~~~~~~\n", "\n", "The `stroke` property is akin to `edgewidth` but applies when a dot mark is defined by its stroke rather than its fill. It also has a slightly different default scale range, but otherwise behaves similarly:" ] }, { "cell_type": "code", "execution_count": null, "id": "f73a0428-a787-4f21-8098-848eb1c816fb", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "(\n", " edge_plot\n", " .add(so.Dot(color=\".2\", marker=\"x\", pointsize=11), stroke=x)\n", " .scale(stroke=None)\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "c2ca33db-df52-4958-889a-320b4833a0d7", "metadata": {}, "source": [ "Text properties\n", "---------------" ] }, { "cell_type": "raw", "id": "b75af2fe-4d81-407c-9858-23362710f25f", "metadata": {}, "source": [ ".. _horizontalalignment_property:\n", "\n", ".. 
_verticalalignment_property:\n", "\n", "halign, valign\n", "~~~~~~~~~~~~~~\n", "\n", "The `halign` and `valign` properties control the *horizontal* and *vertical* alignment of text marks. The options for horizontal alignment are `'left'`, `'right'`, and `'center'`, while the options for vertical alignment are `'top'`, `'bottom'`, `'center'`, `'baseline'`, and `'center_baseline'`." ] }, { "cell_type": "code", "execution_count": null, "id": "e9588309-bee4-4b97-b428-eb91ea582105", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "x = [\"left\", \"right\", \"top\", \"bottom\", \"baseline\", \"center\"]\n", "ha = x[:2] + [\"center\"] * 4\n", "va = [\"center_baseline\"] * 2 + x[2:]\n", "y = np.zeros(len(x))\n", "(\n", " so.Plot(x=[f\"'{_x_}'\" for _x_ in x], y=y, halign=ha, valign=va)\n", " .add(so.Dot(marker=\"+\", color=\"r\", alpha=.5, stroke=1, pointsize=24))\n", " .add(so.Text(text=\"XyZ\", fontsize=14, offset=0))\n", " .scale(y=so.Continuous().tick(at=[]), halign=None, valign=None)\n", " .limit(x=(-.25, len(x) - .75))\n", " .layout(size=(9, .6), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", \"right\", \"top\"]},\n", " \"xtick.labelsize\": 12,\n", " \"axes.xmargin\": .015,\n", " \"ytick.labelsize\": 12,\n", " })\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "ea74c7e5-798b-47bc-bc18-9086902fb5c6", "metadata": {}, "source": [ ".. _fontsize_property:\n", "\n", "fontsize\n", "~~~~~~~~\n", "\n", "The `fontsize` property controls the size of textual marks. 
The value has point units:" ] }, { "cell_type": "code", "execution_count": null, "id": "c515b790-385d-4521-b14a-0769c1902928", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "from string import ascii_uppercase\n", "n = 26\n", "s = np.arange(n) + 1\n", "y = np.zeros(n)\n", "t = list(ascii_uppercase[:n])\n", "(\n", " so.Plot(x=s, y=y, text=t, fontsize=s)\n", " .add(so.Text())\n", " .scale(x=so.Nominal(), y=so.Continuous().tick(at=[]))\n", " .layout(size=(9, .5), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", \"right\", \"top\"]},\n", " \"xtick.labelsize\": 12,\n", " \"axes.xmargin\": .015,\n", " \"ytick.labelsize\": 12,\n", " })\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "4b367f36-fb96-44fa-83a3-1cc66c7a3279", "metadata": {}, "source": [ ".. _offset_property:\n", "\n", "offset\n", "~~~~~~\n", "\n", "The `offset` property controls the spacing between a text mark and its anchor position. It applies when *not* using `center` alignment (i.e., when using left/right or top/bottom). The value has point units. 
" ] }, { "cell_type": "code", "execution_count": null, "id": "25a49331-9580-4578-8bdb-d0d1829dde71", "metadata": { "tags": [ "hide-input" ] }, "outputs": [], "source": [ "n = 17\n", "x = np.linspace(0, 8, n)\n", "y = np.full(n, .5)\n", "(\n", " so.Plot(x=x, y=y, offset=x)\n", " .add(so.Bar(color=\".6\", edgecolor=\"k\"))\n", " .add(so.Text(text=\"abc\", valign=\"bottom\"))\n", " .scale(\n", " x=so.Continuous().tick(every=1, minor=1),\n", " y=so.Continuous().tick(at=[]),\n", " offset=None,\n", " )\n", " .limit(y=(0, 1.5))\n", " .layout(size=(9, .5), engine=None)\n", " .theme({\n", " **axes_style(\"ticks\"),\n", " **{f\"axes.spines.{side}\": False for side in [\"left\", \"right\", \"top\"]},\n", " \"axes.xmargin\": .015,\n", " \"xtick.labelsize\": 12,\n", " \"ytick.labelsize\": 12,\n", " })\n", " .plot()\n", ")" ] }, { "cell_type": "raw", "id": "77723ffd-2da3-4ece-a97a-3c00e864c743", "metadata": {}, "source": [ "Other properties\n", "----------------" ] }, { "cell_type": "raw", "id": "287bb259-0194-4c8c-8836-5e3eb6d88e79", "metadata": {}, "source": [ ".. _property_property:\n", "\n", "text\n", "~~~~\n", "\n", "The `text` property is used to set the content of a textual mark. It is always used literally (not mapped), and cast to string when necessary.\n", "\n", "group\n", "~~~~~\n", "\n", "The `group` property is special in that it does not change anything about the mark's appearance but defines additional data subsets that transforms should operate on independently." 
] }, { "cell_type": "code", "execution_count": null, "id": "f23c9251-1685-4150-b5c2-ab5b0589d8e6", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 5 } ================================================ FILE: doc/_tutorial/regression.ipynb ================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _regression_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Estimating regression fits\n", "==========================" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Many datasets contain multiple quantitative variables, and the goal of an analysis is often to relate those variables to each other. We :ref:`previously discussed ` functions that can accomplish this by showing the joint distribution of two variables. It can be very helpful, though, to use statistical models to estimate a simple relationship between two noisy sets of observations. The functions discussed in this chapter will do so through the common framework of linear regression.\n", "\n", "In the spirit of Tukey, the regression plots in seaborn are primarily intended to add a visual guide that helps to emphasize patterns in a dataset during exploratory data analyses. That is to say that seaborn is not itself a package for statistical analysis. To obtain quantitative measures related to the fit of regression models, you should use `statsmodels `_. 
The goal of seaborn, however, is to make exploring a dataset through visualization quick and easy, as doing so is just as (if not more) important than exploring a dataset through tables of statistics." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "import numpy as np\n", "import seaborn as sns\n", "import matplotlib.pyplot as plt\n", "sns.set_theme(color_codes=True)\n", "np.random.seed(sum(map(ord, \"regression\")))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Functions for drawing linear regression models\n", "----------------------------------------------\n", "\n", "The two functions that can be used to visualize a linear fit are :func:`regplot` and :func:`lmplot`.\n", "\n", "In the simplest invocation, both functions draw a scatterplot of two variables, ``x`` and ``y``, and then fit the regression model ``y ~ x`` and plot the resulting regression line and a 95% confidence interval for that regression:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.regplot(x=\"total_bill\", y=\"tip\", data=tips);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"tip\", data=tips);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "These functions draw similar plots, but :func:`regplot` is an :doc:`axes-level function `, and :func:`lmplot` is a figure-level function. Additionally, :func:`regplot` accepts the ``x`` and ``y`` variables in a variety of formats including simple numpy arrays, :class:`pandas.Series` objects, or as references to variables in a :class:`pandas.DataFrame` object passed to `data`. In contrast, :func:`lmplot` has `data` as a required parameter and the `x` and `y` variables must be specified as strings. 
Finally, only :func:`lmplot` has `hue` as a parameter.\n", "\n", "The core functionality is otherwise similar, though, so this tutorial will focus on :func:`lmplot`.\n", "\n", "It's possible to fit a linear regression when one of the variables takes discrete values; however, the simple scatterplot produced by this kind of dataset is often not optimal:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"size\", y=\"tip\", data=tips);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "One option is to add some random noise (\"jitter\") to the discrete values to make the distribution of those values more clear. Note that jitter is applied only to the scatterplot data and does not influence the regression line fit itself:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"size\", y=\"tip\", data=tips, x_jitter=.05);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A second option is to collapse over the observations in each discrete bin to plot an estimate of central tendency along with a confidence interval:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"size\", y=\"tip\", data=tips, x_estimator=np.mean);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Fitting different kinds of models\n", "---------------------------------\n", "\n", "The simple linear regression model used above is very simple to fit; however, it is not appropriate for some kinds of datasets. The `Anscombe's quartet <https://en.wikipedia.org/wiki/Anscombe%27s_quartet>`_ dataset shows a few examples where simple linear regression provides an identical estimate of a relationship where simple visual inspection clearly shows differences. 
For example, in the first case, the linear regression is a good model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "anscombe = sns.load_dataset(\"anscombe\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'I'\"),\n", " ci=None, scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The linear relationship in the second dataset is the same, but the plot clearly shows that this is not a good model:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'II'\"),\n", " ci=None, scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In the presence of these kinds of higher-order relationships, :func:`lmplot` and :func:`regplot` can fit a polynomial regression model to explore simple kinds of nonlinear trends in the dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'II'\"),\n", " order=2, ci=None, scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "A different problem is posed by \"outlier\" observations that deviate for some reason other than the main relationship under study:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'III'\"),\n", " ci=None, scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In the presence of outliers, it can be useful to fit a robust regression, which uses a different loss function to downweight relatively large residuals:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ 
"sns.lmplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'III'\"),\n", " robust=True, ci=None, scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When the ``y`` variable is binary, simple linear regression also \"works\" but provides implausible predictions:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips[\"big_tip\"] = (tips.tip / tips.total_bill) > .15\n", "sns.lmplot(x=\"total_bill\", y=\"big_tip\", data=tips,\n", " y_jitter=.03);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The solution in this case is to fit a logistic regression, such that the regression line shows the estimated probability of ``y = 1`` for a given value of ``x``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"big_tip\", data=tips,\n", " logistic=True, y_jitter=.03);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Note that the logistic regression estimate is considerably more computationally intensive (this is true of robust regression as well). As the confidence interval around the regression line is computed using a bootstrap procedure, you may wish to turn this off for faster iteration (using ``ci=None``).\n", "\n", "An altogether different approach is to fit a nonparametric regression using a `lowess smoother `_. This approach has the fewest assumptions, although it is computationally intensive and so currently confidence intervals are not computed at all:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"tip\", data=tips,\n", " lowess=True, line_kws={\"color\": \"C1\"});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The :func:`residplot` function can be a useful tool for checking whether the simple regression model is appropriate for a dataset. 
It fits and removes a simple linear regression and then plots the residual values for each observation. Ideally, these values should be randomly scattered around ``y = 0``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.residplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'I'\"),\n", " scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "If there is structure in the residuals, it suggests that simple linear regression is not appropriate:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.residplot(x=\"x\", y=\"y\", data=anscombe.query(\"dataset == 'II'\"),\n", " scatter_kws={\"s\": 80});" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Conditioning on other variables\n", "-------------------------------\n", "\n", "The plots above show many ways to explore the relationship between a pair of variables. Often, however, a more interesting question is \"how does the relationship between these two variables change as a function of a third variable?\" This is where the main differences between :func:`regplot` and :func:`lmplot` appear. 
While :func:`regplot` always shows a single relationship, :func:`lmplot` combines :func:`regplot` with :class:`FacetGrid` to show multiple fits using `hue` mapping or faceting.\n", "\n", "The best way to separate out a relationship is to plot both levels on the same axes and to use color to distinguish them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"tip\", hue=\"smoker\", data=tips);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Unlike :func:`relplot`, it's not possible to map a distinct variable to the style properties of the scatter plot, but you can redundantly code the `hue` variable with marker shape:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"tip\", hue=\"smoker\", data=tips,\n", " markers=[\"o\", \"x\"], palette=\"Set1\");" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To add another variable, you can draw multiple \"facets\" with each level of the variable appearing in the rows or columns of the grid:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"tip\", hue=\"smoker\", col=\"time\", data=tips);" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.lmplot(x=\"total_bill\", y=\"tip\", hue=\"smoker\",\n", " col=\"time\", row=\"sex\", data=tips, height=3);" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Plotting a regression in other contexts\n", "---------------------------------------\n", "\n", "A few other seaborn functions use :func:`regplot` in the context of a larger, more complex plot. The first is the :func:`jointplot` function that we introduced in the :ref:`distributions tutorial <distribution_tutorial>`. 
In addition to the plot styles previously discussed, :func:`jointplot` can use :func:`regplot` to show the linear regression fit on the joint axes by passing ``kind=\"reg\"``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.jointplot(x=\"total_bill\", y=\"tip\", data=tips, kind=\"reg\");" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Using the :func:`pairplot` function with ``kind=\"reg\"`` combines :func:`regplot` and :class:`PairGrid` to show the linear relationship between variables in a dataset. Take care to note how this is different from :func:`lmplot`. In the figure below, the two axes don't show the same relationship conditioned on two levels of a third variable; rather, :class:`PairGrid` is used to show multiple relationships between different pairings of the variables in a dataset:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(tips, x_vars=[\"total_bill\", \"size\"], y_vars=[\"tip\"],\n", " height=5, aspect=.8, kind=\"reg\");" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Conditioning on an additional categorical variable is built into both of these functions using the ``hue`` parameter:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.pairplot(tips, x_vars=[\"total_bill\", \"size\"], y_vars=[\"tip\"],\n", " hue=\"smoker\", height=5, aspect=.8, kind=\"reg\");" ] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/_tutorial/relational.ipynb 
================================================ { "cells": [ { "cell_type": "raw", "metadata": {}, "source": [ ".. _relational_tutorial:\n", "\n", ".. currentmodule:: seaborn" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Visualizing statistical relationships\n", "=====================================\n", "\n", "Statistical analysis is a process of understanding how variables in a dataset relate to each other and how those relationships depend on other variables. Visualization can be a core component of this process because, when data are visualized properly, the human visual system can see trends and patterns that indicate a relationship.\n", "\n", "We will discuss three seaborn functions in this tutorial. The one we will use most is :func:`relplot`. This is a :doc:`figure-level function <function_overview>` for visualizing statistical relationships using two common approaches: scatter plots and line plots. :func:`relplot` combines a :class:`FacetGrid` with one of two axes-level functions:\n", "\n", "- :func:`scatterplot` (with ``kind=\"scatter\"``; the default)\n", "- :func:`lineplot` (with ``kind=\"line\"``)\n", "\n", "As we will see, these functions can be quite illuminating because they use simple and easily-understood representations of data that can nevertheless represent complex dataset structures. They can do so because they plot two-dimensional graphics that can be enhanced by mapping up to three additional variables using the semantics of hue, size, and style." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "sns.set_theme(style=\"darkgrid\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "%matplotlib inline\n", "np.random.seed(sum(map(ord, \"relational\")))" ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. 
_scatterplot_tutorial:\n", "\n", "Relating variables with scatter plots\n", "-------------------------------------\n", "\n", "The scatter plot is a mainstay of statistical visualization. It depicts the joint distribution of two variables using a cloud of points, where each point represents an observation in the dataset. This depiction allows the eye to infer a substantial amount of information about whether there is any meaningful relationship between them.\n", "\n", "There are several ways to draw a scatter plot in seaborn. The most basic, which should be used when both variables are numeric, is the :func:`scatterplot` function. In the :ref:`categorical visualization tutorial <categorical_tutorial>`, we will see specialized tools for using scatterplots to visualize categorical data. The :func:`scatterplot` is the default ``kind`` in :func:`relplot` (it can also be forced by setting ``kind=\"scatter\"``):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "tips = sns.load_dataset(\"tips\")\n", "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "While the points are plotted in two dimensions, another dimension can be added to the plot by coloring the points according to a third variable. 
In seaborn, this is referred to as using a \"hue semantic\", because the color of the point gains meaning:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", hue=\"smoker\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To emphasize the difference between the classes, and to improve accessibility, you can use a different marker style for each class:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips,\n", " x=\"total_bill\", y=\"tip\", hue=\"smoker\", style=\"smoker\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to represent four variables by changing the hue and style of each point independently. But this should be done carefully, because the eye is much less sensitive to shape than to color:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips,\n", " x=\"total_bill\", y=\"tip\", hue=\"smoker\", style=\"time\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In the examples above, the hue semantic was categorical, so the default :ref:`qualitative palette <palette_tutorial>` was applied. If the hue semantic is numeric (specifically, if it can be cast to float), the default coloring switches to a sequential palette:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips, x=\"total_bill\", y=\"tip\", hue=\"size\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "In both cases, you can customize the color palette. There are many options for doing so. 
Here, we customize a sequential palette using the string interface to :func:`cubehelix_palette`:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips, \n", " x=\"total_bill\", y=\"tip\",\n", " hue=\"size\", palette=\"ch:r=-.5,l=.75\"\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The third kind of semantic variable changes the size of each point:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(data=tips, x=\"total_bill\", y=\"tip\", size=\"size\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Unlike with :func:`matplotlib.pyplot.scatter`, the literal value of the variable is not used to pick the area of the point. Instead, the range of values in data units is normalized into a range in area units. This range can be customized:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips, x=\"total_bill\", y=\"tip\",\n", " size=\"size\", sizes=(15, 200)\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "More examples for customizing how the different semantics are used to show statistical relationships are shown in the :func:`scatterplot` API examples." ] }, { "cell_type": "raw", "metadata": {}, "source": [ ".. _lineplot_tutorial:\n", "\n", "Emphasizing continuity with line plots\n", "--------------------------------------\n", "\n", "Scatter plots are highly effective, but there is no universally optimal type of visualisation. Instead, the visual representation should be adapted for the specifics of the dataset and to the question you are trying to answer with the plot.\n", "\n", "With some datasets, you may want to understand changes in one variable as a function of time, or a similarly continuous variable. In this situation, a good choice is to draw a line plot. 
In seaborn, this can be accomplished by the :func:`lineplot` function, either directly or with :func:`relplot` by setting ``kind=\"line\"``:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dowjones = sns.load_dataset(\"dowjones\")\n", "sns.relplot(data=dowjones, x=\"Date\", y=\"Price\", kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Aggregation and representing uncertainty\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "More complex datasets will have multiple measurements for the same value of the ``x`` variable. The default behavior in seaborn is to aggregate the multiple measurements at each ``x`` value by plotting the mean and the 95% confidence interval around the mean:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "fmri = sns.load_dataset(\"fmri\")\n", "sns.relplot(data=fmri, x=\"timepoint\", y=\"signal\", kind=\"line\")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The confidence intervals are computed using bootstrapping, which can be time-intensive for larger datasets. 
It's therefore possible to disable them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", errorbar=None,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Another good option, especially with larger data, is to represent the spread of the distribution at each timepoint by plotting the standard deviation instead of a confidence interval:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", errorbar=\"sd\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "To turn off aggregation altogether, set the ``estimator`` parameter to ``None``. This might produce a strange effect when the data have multiple observations at each point." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\",\n", " estimator=None,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Plotting subsets of data with semantic mappings\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "The :func:`lineplot` function has the same flexibility as :func:`scatterplot`: it can show up to three additional variables by modifying the hue, size, and style of the plot elements. It does so using the same API as :func:`scatterplot`, meaning that we don't need to stop and think about the parameters that control the look of lines vs. points in matplotlib.\n", "\n", "Using semantics in :func:`lineplot` will also determine how the data get aggregated. For example, adding a hue semantic with two levels splits the plot into two lines and error bands, coloring each to indicate which subset of the data they correspond to." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", hue=\"event\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Adding a style semantic to a line plot changes the pattern of dashes in the line by default:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\",\n", " hue=\"region\", style=\"event\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "But you can identify subsets by the markers used at each observation, either together with the dashes or instead of them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", hue=\"region\", style=\"event\",\n", " dashes=False, markers=True,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "As with scatter plots, be cautious about making line plots using multiple semantics. While sometimes informative, they can also be difficult to parse and interpret. But even when you are only examining changes across one additional variable, it can be useful to alter both the color and style of the lines. This can make the plot more accessible when printed to black-and-white or viewed by someone with color blindness:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", hue=\"event\", style=\"event\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When you are working with repeated measures data (that is, you have units that were sampled multiple times), you can also plot each sampling unit separately without distinguishing them through semantics. 
This avoids cluttering the legend:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri.query(\"event == 'stim'\"), kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", hue=\"region\",\n", " units=\"subject\", estimator=None,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The default colormap and handling of the legend in :func:`lineplot` also depends on whether the hue semantic is categorical or numeric:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "dots = sns.load_dataset(\"dots\").query(\"align == 'dots'\")\n", "sns.relplot(\n", " data=dots, kind=\"line\",\n", " x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", style=\"choice\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It may happen that, even though the ``hue`` variable is numeric, it is poorly represented by a linear color scale. That's the case here, where the levels of the ``hue`` variable are logarithmically scaled. 
You can provide specific color values for each line by passing a list or dictionary:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "palette = sns.cubehelix_palette(light=.8, n_colors=6)\n", "sns.relplot(\n", " data=dots, kind=\"line\", \n", " x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", style=\"choice\", palette=palette,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Or you can alter how the colormap is normalized:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from matplotlib.colors import LogNorm\n", "palette = sns.cubehelix_palette(light=.7, n_colors=6)\n", "sns.relplot(\n", " data=dots.query(\"coherence > 0\"), kind=\"line\",\n", " x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", style=\"choice\",\n", " hue_norm=LogNorm(),\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "The third semantic, size, changes the width of the lines:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=dots, kind=\"line\",\n", " x=\"time\", y=\"firing_rate\",\n", " size=\"coherence\", style=\"choice\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "While the ``size`` variable will typically be numeric, it's also possible to map a categorical variable with the width of the lines. Be cautious when doing so, because it will be difficult to distinguish much more than \"thick\" vs \"thin\" lines. 
However, dashes can be hard to perceive when lines have high-frequency variability, so using different widths may be more effective in that case:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=dots, kind=\"line\",\n", " x=\"time\", y=\"firing_rate\",\n", " hue=\"coherence\", size=\"choice\", palette=palette,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Controlling sorting and orientation\n", "~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~\n", "\n", "Because :func:`lineplot` assumes that you are most often trying to draw ``y`` as a function of ``x``, the default behavior is to sort the data by the ``x`` values before plotting. However, this can be disabled:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "healthexp = sns.load_dataset(\"healthexp\").sort_values(\"Year\")\n", "sns.relplot(\n", " data=healthexp, kind=\"line\",\n", " x=\"Spending_USD\", y=\"Life_Expectancy\", hue=\"Country\",\n", " sort=False\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "It's also possible to sort (and aggregate) along the y axis:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"signal\", y=\"timepoint\", hue=\"event\",\n", " orient=\"y\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "Showing multiple relationships with facets\n", "------------------------------------------\n", "\n", "We've emphasized in this tutorial that, while these functions *can* show several semantic variables at once, it's not always effective to do so. But what about when you do want to understand how a relationship between two variables depends on more than one other variable?\n", "\n", "The best approach may be to make more than one plot. Because :func:`relplot` is based on the :class:`FacetGrid`, this is easy to do. 
To show the influence of an additional variable, instead of assigning it to one of the semantic roles in the plot, use it to \"facet\" the visualization. This means that you make multiple axes and plot subsets of the data on each of them:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=tips,\n", " x=\"total_bill\", y=\"tip\", hue=\"smoker\", col=\"time\",\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "You can also show the influence of two variables this way: one by faceting on the columns and one by faceting on the rows. As you start adding more variables to the grid, you may want to decrease the figure size. Remember that the size of a :class:`FacetGrid` is parameterized by the height and aspect ratio of *each facet*:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "tags": [ "hide" ] }, "outputs": [], "source": [ "subject_number = fmri[\"subject\"].str[1:].astype(int)\n", "fmri= fmri.iloc[subject_number.argsort()]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri, kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", hue=\"subject\",\n", " col=\"region\", row=\"event\", height=3,\n", " estimator=None\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "When you want to examine effects across many levels of a variable, it can be a good idea to facet that variable on the columns and then \"wrap\" the facets into the rows:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "sns.relplot(\n", " data=fmri.query(\"region == 'frontal'\"), kind=\"line\",\n", " x=\"timepoint\", y=\"signal\", hue=\"event\", style=\"event\",\n", " col=\"subject\", col_wrap=5,\n", " height=3, aspect=.75, linewidth=2.5,\n", ")" ] }, { "cell_type": "raw", "metadata": {}, "source": [ "These visualizations, which are sometimes called \"lattice\" plots or 
\"small-multiples\", are very effective because they present the data in a format that makes it easy for the eye to detect both overall patterns and deviations from those patterns. While you should make use of the flexibility afforded by :func:`scatterplot` and :func:`relplot`, always try to keep in mind that several simple plots are usually more effective than one complex plot." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "celltoolbar": "Tags", "kernelspec": { "display_name": "py310", "language": "python", "name": "py310" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.6" } }, "nbformat": 4, "nbformat_minor": 4 } ================================================ FILE: doc/api.rst ================================================ .. _api_ref: API reference ============= .. currentmodule:: seaborn.objects .. _objects_api: Objects interface ----------------- Plot object ~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :template: plot :nosignatures: Plot Mark objects ~~~~~~~~~~~~ .. rubric:: Dot marks .. autosummary:: :toctree: generated/ :template: object :nosignatures: Dot Dots .. rubric:: Line marks .. autosummary:: :toctree: generated/ :template: object :nosignatures: Line Lines Path Paths Dash Range .. rubric:: Bar marks .. autosummary:: :toctree: generated/ :template: object :nosignatures: Bar Bars .. rubric:: Fill marks .. autosummary:: :toctree: generated/ :template: object :nosignatures: Area Band .. rubric:: Text marks .. autosummary:: :toctree: generated/ :template: object :nosignatures: Text Stat objects ~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :template: object :nosignatures: Agg Est Count Hist KDE Perc PolyFit Move objects ~~~~~~~~~~~~ .. 
autosummary:: :toctree: generated/ :template: object :nosignatures: Dodge Jitter Norm Stack Shift Scale objects ~~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :template: scale :nosignatures: Boolean Continuous Nominal Temporal Base classes ~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :template: object :nosignatures: Mark Stat Move Scale .. currentmodule:: seaborn Function interface ------------------ .. _relational_api: Relational plots ~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: relplot scatterplot lineplot .. _distribution_api: Distribution plots ~~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: displot histplot kdeplot ecdfplot rugplot distplot .. _categorical_api: Categorical plots ~~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: catplot stripplot swarmplot boxplot violinplot boxenplot pointplot barplot countplot .. _regression_api: Regression plots ~~~~~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: lmplot regplot residplot .. _matrix_api: Matrix plots ~~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: heatmap clustermap .. _grid_api: Multi-plot grids ---------------- Facet grids ~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: FacetGrid Pair grids ~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: pairplot PairGrid Joint grids ~~~~~~~~~~~ .. autosummary:: :toctree: generated/ :nosignatures: jointplot JointGrid .. _style_api: Themeing -------- .. autosummary:: :toctree: generated/ :nosignatures: set_theme axes_style set_style plotting_context set_context set_color_codes reset_defaults reset_orig set .. _palette_api: Color palettes -------------- .. autosummary:: :toctree: generated/ :nosignatures: set_palette color_palette husl_palette hls_palette cubehelix_palette dark_palette light_palette diverging_palette blend_palette xkcd_palette crayon_palette mpl_palette Palette widgets ~~~~~~~~~~~~~~~ .. 
autosummary:: :toctree: generated/ :nosignatures: choose_colorbrewer_palette choose_cubehelix_palette choose_light_palette choose_dark_palette choose_diverging_palette Utility functions ----------------- .. autosummary:: :toctree: generated/ :nosignatures: despine move_legend saturate desaturate set_hls_values load_dataset get_dataset_names get_data_home ================================================ FILE: doc/citing.rst ================================================ .. _citing: Citing and logo =============== Citing seaborn -------------- If seaborn is integral to a scientific publication, please cite it. A paper describing seaborn has been published in the `Journal of Open Source Software `_: Waskom, M. L., (2021). seaborn: statistical data visualization. Journal of Open Source Software, 6(60), 3021, https://doi.org/10.21105/joss.03021. Here is a ready-made BibTeX entry: .. highlight:: none :: @article{Waskom2021, doi = {10.21105/joss.03021}, url = {https://doi.org/10.21105/joss.03021}, year = {2021}, publisher = {The Open Journal}, volume = {6}, number = {60}, pages = {3021}, author = {Michael L. Waskom}, title = {seaborn: statistical data visualization}, journal = {Journal of Open Source Software} } In most situations where seaborn is cited, a citation to `matplotlib `_ would also be appropriate. Logo files ---------- Additional logo files, including hi-res PNGs and images suitable for use over a dark background, are available `on GitHub `_. Wide logo ~~~~~~~~~ .. image:: _static/logo-wide-lightbg.svg :width: 400px Tall logo ~~~~~~~~~ .. image:: _static/logo-tall-lightbg.svg :width: 150px Logo mark ~~~~~~~~~ .. image:: _static/logo-mark-lightbg.svg :width: 150px Credit to `Matthias Bussonnier `_ for the initial design and implementation of the logo. ================================================ FILE: doc/conf.py ================================================ # Configuration file for the Sphinx documentation builder. 
# # This file only contains a selection of the most common options. For a full # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys import time import seaborn from seaborn._core.properties import PROPERTIES sys.path.insert(0, os.path.abspath('sphinxext')) # -- Project information ----------------------------------------------------- project = 'seaborn' copyright = f'2012-{time.strftime("%Y")}' author = 'Michael Waskom' version = release = seaborn.__version__ # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ 'sphinx.ext.autodoc', 'sphinx.ext.doctest', 'sphinx.ext.coverage', 'sphinx.ext.mathjax', 'sphinx.ext.autosummary', 'sphinx.ext.intersphinx', 'matplotlib.sphinxext.plot_directive', 'gallery_generator', 'tutorial_builder', 'numpydoc', 'sphinx_copybutton', 'sphinx_issues', 'sphinx_design', ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The root document. root_doc = 'index' # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'docstrings', 'nextgen', 'Thumbs.db', '.DS_Store'] # The reST default role (used for this markup: `text`) to use for all documents. 
default_role = 'literal' # Generate the API documentation when building autosummary_generate = True numpydoc_show_class_members = False # Sphinx-issues configuration issues_github_path = 'mwaskom/seaborn' # Include the example source for plots in API docs plot_include_source = True plot_formats = [('png', 90)] plot_html_show_formats = False plot_html_show_source_link = False # Don't add a source link in the sidebar html_show_sourcelink = False # Control the appearance of type hints autodoc_typehints = "none" autodoc_typehints_format = "short" # Allow shorthand references for main function interface rst_prolog = """ .. currentmodule:: seaborn """ # Define replacements (used in whatsnew bullets) rst_epilog = r""" .. role:: raw-html(raw) :format: html .. role:: raw-latex(raw) :format: latex .. |API| replace:: :raw-html:`API` :raw-latex:`{\small\sc [API]}` .. |Defaults| replace:: :raw-html:`Defaults` :raw-latex:`{\small\sc [Defaults]}` .. |Docs| replace:: :raw-html:`Docs` :raw-latex:`{\small\sc [Docs]}` .. |Feature| replace:: :raw-html:`Feature` :raw-latex:`{\small\sc [Feature]}` .. |Enhancement| replace:: :raw-html:`Enhancement` :raw-latex:`{\small\sc [Enhancement]}` .. |Fix| replace:: :raw-html:`Fix` :raw-latex:`{\small\sc [Fix]}` .. |Build| replace:: :raw-html:`Build` :raw-latex:`{\small\sc [Deps]}` """ # noqa rst_epilog += "\n".join([ f".. |{key}| replace:: :ref:`{key} <{val.__class__.__name__.lower()}_property>`" for key, val in PROPERTIES.items() ]) # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'pydata_sphinx_theme' # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named 'default.css' will overwrite the builtin 'default.css'. 
html_static_path = ['_static', 'example_thumbs'] for path in html_static_path: if not os.path.exists(path): os.makedirs(path) html_css_files = [f'css/custom.css?v={seaborn.__version__}'] html_logo = "_static/logo-wide-lightbg.svg" html_favicon = "_static/favicon.ico" html_theme_options = { "icon_links": [ { "name": "GitHub", "url": "https://github.com/mwaskom/seaborn", "icon": "fab fa-github", "type": "fontawesome", }, { "name": "StackOverflow", "url": "https://stackoverflow.com/tags/seaborn", "icon": "fab fa-stack-overflow", "type": "fontawesome", }, { "name": "Twitter", "url": "https://twitter.com/michaelwaskom", "icon": "fab fa-twitter", "type": "fontawesome", }, ], "show_prev_next": False, "navbar_start": ["navbar-logo"], "navbar_end": ["navbar-icon-links"], "header_links_before_dropdown": 8, } html_context = { "default_mode": "light", } html_sidebars = { "index": [], "examples/index": [], "**": ["sidebar-nav-bs.html"], } # -- Intersphinx ------------------------------------------------ intersphinx_mapping = { 'numpy': ('https://numpy.org/doc/stable/', None), 'scipy': ('https://docs.scipy.org/doc/scipy/', None), 'matplotlib': ('https://matplotlib.org/stable', None), 'pandas': ('https://pandas.pydata.org/pandas-docs/stable/', None), 'statsmodels': ('https://www.statsmodels.org/stable/', None) } ================================================ FILE: doc/example_thumbs/.gitkeep ================================================ ================================================ FILE: doc/faq.rst ================================================ .. currentmodule:: seaborn Frequently asked questions ========================== This is a collection of answers to questions that are commonly raised about seaborn. Getting started --------------- .. _faq_cant_import: I've installed seaborn, why can't I import it? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *It looks like you successfully installed seaborn by doing* `pip install seaborn` *but it cannot be imported. 
You get an error like "ModuleNotFoundError: No module named 'seaborn'" when you try.* This is probably not a `seaborn` problem, *per se*. If you have multiple Python environments on your computer, it is possible that you did `pip install` in one environment and tried to import the library in another. On a unix system, you could check whether the terminal commands `which pip`, `which python`, and (if applicable) `which jupyter` point to the same `bin/` directory. If not, you'll need to sort out the definition of your `$PATH` variable. Two alternate patterns for installing with `pip` may also be more robust to this problem: - Invoke `pip` on the command line with `python -m pip install <package>` rather than `pip install <package>` - Use `%pip install <package>` in a Jupyter notebook to install it in the same place as the kernel .. _faq_import_fails: I can't import seaborn, even though it's definitely installed! ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You've definitely installed seaborn in the right place, but importing it produces a long traceback and a confusing error message, perhaps something like* `ImportError: DLL load failed: The specified module could not be found`. Such errors usually indicate a problem with the way Python libraries are using compiled resources. Because seaborn is pure Python, it won't directly encounter these problems, but its dependencies (numpy, scipy, matplotlib, and pandas) might. To fix the issue, you'll first need to read through the traceback and figure out which dependency was being imported at the time of the error. Then consult the installation documentation for the relevant package, which might have advice for getting an installation working on your specific system. The most common culprit of these issues is scipy, which has many compiled components. Starting in seaborn version 0.12, scipy is an optional dependency, which should help to reduce the frequency of these issues. .. _faq_no_plots: Why aren't my plots showing up? 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You're calling seaborn functions — maybe in a terminal or IDE with an integrated IPython console — but not seeing any plots.* In matplotlib, there is a distinction between *creating* a figure and *showing* it, and in some cases it's necessary to explicitly call :func:`matplotlib.pyplot.show` at the point when you want to see the plot. Because that command blocks by default and is not always desired (for instance, you may be executing a script that saves files to disk) seaborn does not deviate from standard matplotlib practice here. Yet most of the examples in the seaborn docs do not have this line, because there are multiple ways to avoid needing it. In a Jupyter notebook with the `"inline" `_ (default) or `"widget" `_ backends, :func:`matplotlib.pyplot.show` is automatically called after executing a cell, so any figures will appear in the cell's outputs. You can also activate a more interactive experience by executing `%matplotlib` in any Jupyter or IPython interface or by calling :func:`matplotlib.pyplot.ion` anywhere in Python. Both methods will configure matplotlib to show or update the figure after every plotting command. .. _faq_repl_output: Why is something printed after every notebook cell? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You're using seaborn in a Jupyter notebook, and every cell prints something like <AxesSubplot:> or <seaborn.axisgrid.FacetGrid at 0x...> before showing the plot.* Jupyter notebooks will show the result of the final statement in the cell as part of its output, and each of seaborn's plotting functions returns a reference to the matplotlib or seaborn object that contains the plot. If this is bothersome, you can suppress this output in a few ways: - Always assign the result of the final statement to a variable (e.g. `ax = sns.histplot(...)`) - Add a semicolon to the end of the final statement (e.g. `sns.histplot(...);`) - End every cell with a function that has no return value (e.g. 
`plt.show()`, which isn't needed but also causes no problems) - Add `cell metadata tags `_, if you're converting the notebook to a different representation .. _faq_inline_dpi: Why do the plots look fuzzy in a Jupyter notebook? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The default "inline" backend (defined by `IPython `_) uses an unusually low dpi (`"dots per inch" `_) for figure output. This is a space-saving measure: lower dpi figures take up less disk space. (Also, lower dpi inline graphics appear *physically* smaller because they are represented as `PNGs `_, which do not exactly have a concept of resolution.) So one faces an economy/quality tradeoff. You can increase the DPI by resetting the rc parameters through the matplotlib API, using :: plt.rcParams.update({"figure.dpi": 96}) Or do it as you activate the seaborn theme:: sns.set_theme(rc={"figure.dpi": 96}) If you have a high pixel-density monitor, you can make your plots sharper using "retina mode":: %config InlineBackend.figure_format = "retina" This won't change the apparent size of your plots in a Jupyter interface, but they might appear very large in other contexts (i.e. on GitHub). And they will take up 4x the disk space. Alternatively, you can make SVG plots:: %config InlineBackend.figure_format = "svg" This will configure matplotlib to emit `vector graphics `_ with "infinite resolution". The downside is that file size will now scale with the number and complexity of the artists in your plot, and in some cases (e.g., a large scatterplot matrix) the load will impact browser responsiveness. Tricky concepts --------------- .. _faq_function_levels: What do "figure-level" and "axes-level" mean? 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You've encountered the term "figure-level" or "axes-level", maybe in the seaborn docs, a StackOverflow answer, or a GitHub thread, but you don't understand what it means.* In brief, all plotting functions in seaborn fall into one of two categories: - "axes-level" functions, which plot onto a single subplot that may or may not exist at the time the function is called - "figure-level" functions, which internally create a matplotlib figure, potentially including multiple subplots This design is intended to satisfy two objectives: - seaborn should offer functions that are "drop-in" replacements for matplotlib methods - seaborn should be able to produce figures that show "facets" or marginal distributions on distinct subplots The figure-level functions always combine one or more axes-level functions with an object that manages the layout. So, for example, :func:`relplot` is a figure-level function that combines either :func:`scatterplot` or :func:`lineplot` with a :class:`FacetGrid`. In contrast, :func:`jointplot` is a figure-level function that can combine multiple different axes-level functions — :func:`scatterplot` and :func:`histplot` by default — with a :class:`JointGrid`. If all you're doing is creating a plot with a single seaborn function call, this is not something you need to worry too much about. But it becomes relevant when you want to customize at a level beyond what the API of each function offers. It is also the source of various other points of confusion, so it is an important distinction to understand (at least broadly) and keep in mind. This is explained in more detail in the :doc:`tutorial ` and in `this blog post `_. .. _faq_categorical_plots: What is a "categorical plot" or "categorical function"? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Next to the figure-level/axes-level distinction, this concept is probably the second biggest source of confusing behavior. 
Several :ref:`seaborn functions ` are referred to as "categorical" because they are designed to support a use-case where either the x or y variable in a plot is categorical (that is, the variable takes a finite number of potentially non-numeric values). At the time these functions were written, matplotlib did not have any direct support for non-numeric data types. So seaborn internally builds a mapping from unique values in the data to 0-based integer indexes, which is what it passes to matplotlib. If your data are strings, that's great, and it more-or-less matches how `matplotlib now handles `_ string-typed data. But a potential gotcha is that these functions *always do this by default*, even if both the x and y variables are numeric. This gives rise to a number of confusing behaviors, especially when mixing categorical and non-categorical plots (e.g., a combo bar-and-line plot). The v0.13 release added a `native_scale` parameter which provides control over this behavior. It is `False` by default, but setting it to `True` will preserve the original properties of the data used for categorical grouping. Specifying data --------------- .. _faq_data_format: How does my data need to be organized? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ To get the most out of seaborn, your data should have a "long-form" or "tidy" representation. In a dataframe, `this means that `_ each variable has its own column, each observation has its own row, and each value has its own cell. With long-form data, you can succinctly and exactly specify a visualization by assigning variables in the dataset (columns) to roles in the plot. Data organization is a common stumbling block for beginners, in part because data are often not collected or stored in a long-form representation. Therefore, it is often necessary to `reshape `_ the data using pandas before plotting. Data reshaping can be a complex undertaking, requiring both a solid grasp of dataframe structure and knowledge of the pandas API. 
Investing some time in developing this skill can pay large dividends. But while seaborn is *most* powerful when provided with long-form data, nearly every seaborn function will accept and plot "wide-form" data too. You can trigger this by passing an object to seaborn's `data=` parameter without specifying other plot variables (`x`, `y`, ...). You'll be limited when using wide-form data: each function can make only one kind of wide-form plot. In most cases, seaborn tries to match what matplotlib or pandas would do with a dataset of the same structure. Reshaping your data into long-form will give you substantially more flexibility, but it can be helpful to take a quick look at your data very early in the process, and seaborn tries to make this possible. Understanding how your data should be represented — and how to get it that way if it starts out messy — is very important for making efficient and complete use of seaborn, and it is elaborated on at length in the :doc:`user-guide `. .. _faq_pandas_requirement: Does seaborn only work with pandas? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Generally speaking, no: seaborn is `quite flexible `_ about how your dataset needs to be represented. In most cases, :ref:`long-form data ` represented by multiple vector-like types can be passed directly to `x`, `y`, or other plotting parameters. Or you can pass a dictionary of vector types to `data` rather than a DataFrame. And when plotting with wide-form data, you can use a 2D numpy array or even nested lists to plot in wide-form mode. There are a couple older functions (namely, :func:`catplot` and :func:`lmplot`) that do require you to pass a :class:`pandas.DataFrame`. But at this point, they are the exception, and they will gain more flexibility over the next few release cycles. Layout problems --------------- .. _faq_figure_size: How do I change the figure size? 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This is going to be more complicated than you might hope, in part because there are multiple ways to change the figure size in matplotlib, and in part because of the :ref:`figure-level/axes-level ` distinction in seaborn. In matplotlib, you can usually set the default size for all figures through the `rc parameters `_, specifically `figure.figsize`. And you can set the size of an individual figure when you create it (e.g. `plt.subplots(figsize=(w, h))`). If you're using an axes-level seaborn function, both of these will work as expected. Figure-level functions both ignore the default figure size and :ref:`parameterize the figure size differently `. When calling a figure-level function, you can pass values to `height=` and `aspect=` to set (roughly) the size of each *subplot*. The advantage here is that the size of the figure automatically adapts when you add faceting variables. But it can be confusing. Fortunately, there's a consistent way to set the exact figure size in a function-independent manner. Instead of setting the figure size when the figure is created, modify it after you plot by calling `obj.figure.set_size_inches(...)`, where `obj` is either a matplotlib axes (usually assigned to `ax`) or a seaborn `FacetGrid` (usually assigned to `g`). Note that :attr:`FacetGrid.figure` exists only on seaborn >= 0.11.2; before that you'll have to access :attr:`FacetGrid.fig`. Also, if you're making pngs (or in a Jupyter notebook), you can — perhaps surprisingly — scale all your plots up or down by :ref:`changing the dpi `. .. _faq_plot_misplaced: Why isn't seaborn drawing the plot where I tell it to? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You've explicitly created a matplotlib figure with one or more subplots and tried to draw a seaborn plot on it, but you end up with an extra figure and a blank subplot. 
Perhaps your code looks something like* :: f, ax = plt.subplots() sns.catplot(..., ax=ax) This is a :ref:`figure-level/axes-level ` gotcha. Figure-level functions always create their own figure, so you can't direct them towards an existing axes the way you can with axes-level functions. Most functions will warn you when this happens, suggest the appropriate axes-level function, and ignore the `ax=` parameter. A few older functions might put the plot where you want it (because they internally pass `ax` to their axes-level function) while still creating an extra figure. This latter behavior should be considered a bug, and it is not to be relied on. The way things currently work, you can either set up the matplotlib figure yourself, or you can use a figure-level function, but you can't do both at the same time. .. _faq_categorical_line: Why can't I draw a line over a bar/box/strip/violin plot? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You're trying to create a single plot using multiple seaborn functions, perhaps by drawing a lineplot or regplot over a barplot or violinplot. You expect the line to go through the mean value for each box (etc.), but it looks to be misaligned, or maybe it's all the way off to the side.* You are trying to combine a :ref:`"categorical plot" ` with another plot type. If your `x` variable has numeric values, it seems like this should work. But recall: seaborn's categorical plots map unique values on the categorical axis to integer indexes. So if your data have unique `x` values of 1, 6, 20, 94, the corresponding plot elements will get drawn at 0, 1, 2, 3 (and the tick labels will be changed to represent the actual value). The line or regression plot doesn't know that this has happened, so it will use the actual numeric values, and the plots won't line up at all. As of now, there are two ways to work around this. 
In situations where you want to draw a line, you could use the (somewhat misleadingly named) :func:`pointplot` function, which is also a "categorical" function and will use the same rules for drawing the plot. If this doesn't solve the problem (for one, it's not as visually flexible as :func:`lineplot`), you could implement the mapping from actual values to integer indexes yourself and draw the plot that way:: unique_xs = sorted(df["x"].unique()) sns.violinplot(data=df, x="x", y="y") sns.lineplot(data=df, x=df["x"].map(unique_xs.index), y="y") This is something that will be easier in a planned future release, as it will become possible to make the categorical functions treat numeric data as numeric. (As of v0.12, it's possible only in :func:`stripplot` and :func:`swarmplot`, using `native_scale=True`). How do I move the legend? ~~~~~~~~~~~~~~~~~~~~~~~~~ *When applying a semantic mapping to a plot, seaborn will automatically create a legend and add it to the figure. But the automatic choice of legend position is not always ideal.* With seaborn v0.11.2 or later, use the :func:`move_legend` function. On older versions, a common pattern was to call `ax.legend(loc=...)` after plotting. While this appears to move the legend, it actually *replaces* it with a new one, using any labeled artists that happen to be attached to the axes. This does `not consistently work `_ across plot types. And it does not propagate the legend title or positioning tweaks that are used to format a multi-variable legend. The :func:`move_legend` function is actually more powerful than its name suggests, and it can also be used to modify other `legend parameters `_ (font size, handle length, etc.) after plotting. Other customizations -------------------- .. _faq_figure_customization: How can I change something about the figure? 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You want to make a very specific plot, and seaborn's defaults aren't doing it for you.* There's basically a four-layer hierarchy to customizing a seaborn figure: 1. Explicit seaborn function parameters 2. Passed-through matplotlib keyword arguments 3. Matplotlib axes methods 4. Matplotlib artist methods First, read through the API docs for the relevant seaborn function. Each has a lot of parameters (probably too many), and you may be able to accomplish your desired customization using seaborn's own API. But seaborn does delegate a lot of customization to matplotlib. Most functions have `**kwargs` in their signature, which will catch extra keyword arguments and pass them through to the underlying matplotlib function. For example, :func:`scatterplot` has a number of parameters, but you can also use any valid keyword argument for :meth:`matplotlib.axes.Axes.scatter`, which it calls internally. Passing through keyword arguments lets you customize the artists that represent data, but often you will want to customize other aspects of the figure, such as labels, ticks, and titles. You can do this by calling methods on the object that seaborn's plotting functions return. Depending on whether you're calling an :ref:`axes-level or figure-level function `, this may be a :class:`matplotlib.axes.Axes` object or a seaborn wrapper (such as :class:`seaborn.FacetGrid`). Both kinds of objects have numerous methods that you can call to customize nearly anything about the figure. The easiest thing is usually to call :meth:`matplotlib.axes.Axes.set` or :meth:`seaborn.FacetGrid.set`, which let you modify multiple attributes at once, e.g.:: ax = sns.scatterplot(...) ax.set( xlabel="The x label", ylabel="The y label", title="The title", xlim=(xmin, xmax), xticks=[...], xticklabels=[...], ) Finally, the deepest customization may require you to reach "into" the matplotlib axes and tweak the artists that are stored on it. 
These will be in artist lists, such as `ax.lines`, `ax.collections`, `ax.patches`, etc. *Warning:* Neither matplotlib nor seaborn considers the specific artists produced by their plotting functions to be part of a stable API. Because it's not possible to gracefully warn about upcoming changes to the artist types or the order in which they are stored, code that interacts with these attributes could break unexpectedly. With that said, seaborn does try hard to avoid making this kind of change. .. _faq_matplotlib_requirement: Wait, I need to learn how to use matplotlib too? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ It really depends on how much customization you need. You can certainly perform a lot of exploratory data analysis while primarily or exclusively interacting with the seaborn API. But, if you're polishing a figure for a presentation or publication, you'll likely find yourself needing to understand at least a little bit about how matplotlib works. Matplotlib is extremely flexible, and it lets you control literally everything about a figure if you drill down far enough. Seaborn was originally designed with the idea that it would handle a specific set of well-defined operations through a very high-level API, while letting users "drop down" to matplotlib when they desired additional customization. This can be a pretty powerful combination, and it works reasonably well if you already know how to use matplotlib. But as seaborn has gained more features, it has become more feasible to learn seaborn *first*. In that situation, the need to switch APIs tends to be a bit more confusing / frustrating. This has motivated the development of seaborn's new :doc:`objects interface `, which aims to provide a more cohesive API for both high-level and low-level figure specification. Hopefully, it will alleviate the "two-library problem" as it matures. 
With that said, the level of deep control that matplotlib affords really can't be beat, so if you care about doing very specific things, it really is worth learning. .. _faq_object_oriented: How do I use seaborn with matplotlib's object-oriented interface? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You prefer to use matplotlib's explicit or* `"object-oriented" `_ *interface, because it makes your code easier to reason about and maintain. But the object-oriented interface consists of methods on matplotlib objects, whereas seaborn offers you independent functions.* This is another case where it will be helpful to keep the :ref:`figure-level/axes-level ` distinction in mind. Axes-level functions can be used like any matplotlib axes method, but instead of calling `ax.func(...)`, you call `func(..., ax=ax)`. They also return the axes object (which they may have created, if no figure was currently active in matplotlib's global state). You can use the methods on that object to further customize the plot even if you didn't start with :func:`matplotlib.pyplot.figure` or :func:`matplotlib.pyplot.subplots`:: ax = sns.histplot(...) ax.set(...) Figure-level functions :ref:`can't be directed towards an existing figure `, but they do store the matplotlib objects on the :class:`FacetGrid` object that they return (which seaborn docs always assign to a variable named `g`). If your figure-level function created only one subplot, you can access it directly:: g = sns.displot(...) g.ax.set(...) For multiple subplots, you can either use :attr:`FacetGrid.axes` (which is always a 2D array of axes) or :attr:`FacetGrid.axes_dict` (which maps the row/col keys to the corresponding matplotlib object):: g = sns.displot(..., col=...) for col, ax in g.axes_dict.items(): ax.set(...) But if you're batch-setting attributes on all subplots, use the :meth:`FacetGrid.set` method rather than iterating over the individual axes:: g = sns.displot(...) g.set(...) 
To access the underlying matplotlib *figure*, use :attr:`FacetGrid.figure` on seaborn >= 0.11.2 (or :attr:`FacetGrid.fig` on any other version). .. _faq_bar_annotations: Can I annotate bar plots with the bar values? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Nothing like this is built into seaborn, but matplotlib v3.4.0 added a convenience function (:meth:`matplotlib.axes.Axes.bar_label`) that makes it relatively easy. Here are a couple of recipes; note that you'll need to use a different approach depending on whether your bars come from a :ref:`figure-level or axes-level function `:: # Axes-level ax = sns.histplot(df, x="x_var") for bars in ax.containers: ax.bar_label(bars) # Figure-level, one subplot g = sns.displot(df, x="x_var") for bars in g.ax.containers: g.ax.bar_label(bars) # Figure-level, multiple subplots g = sns.displot(df, x="x_var", col="col_var") for ax in g.axes.flat: for bars in ax.containers: ax.bar_label(bars) .. _faq_dar_mode: Can I use seaborn in dark mode? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ There's no direct support for this in seaborn, but matplotlib has a `"dark_background" `_ style-sheet that you could use, e.g.:: sns.set_theme(style="ticks", rc=plt.style.library["dark_background"]) Note that "dark_background" changes the default color palette to "Set2", and that will override any palette you define in :func:`set_theme`. If you'd rather use a different color palette, you'll have to call :func:`set_palette` separately. The default :doc:`seaborn palette ` ("deep") has poor contrast against a dark background, so you'd be better off using "muted", "bright", or "pastel". Statistical inquiries --------------------- .. _faq_stat_results: Can I access the results of seaborn's statistical transformations? 
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Because seaborn performs some statistical operations as it builds plots (aggregating, bootstrapping, fitting regression models), some users would like access to the statistics that it computes. This is not possible: it's explicitly considered out of scope for seaborn (a visualization library) to offer an API for interrogating statistical models. If you simply want to be diligent and verify that seaborn is doing things correctly (or that it matches your own code), it's open-source, so feel free to read the code. Or, because it's Python, you can call into the private methods that calculate the stats (just don't do this in production code). But don't expect seaborn to offer features that are more at home in `scipy `_ or `statsmodels `_. .. _faq_standard_error: Can I show standard error instead of a confidence interval? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ As of v0.12, this is possible in most places, using the new `errorbar` API (see the :doc:`tutorial ` for more details). .. _faq_kde_value: Why does the y axis for a KDE plot go above 1? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ *You've estimated a probability distribution for your data using* :func:`kdeplot`, *but the y axis goes above 1. Aren't probabilities bounded by 1? Is this a bug?* This is not a bug, but it is a common confusion (about kernel density plots and probability distributions more broadly). A continuous probability distribution is defined by a `probability density function `_, which :func:`kdeplot` estimates. The probability density function does **not** output *a probability*: a continuous random variable can take an infinite number of values, so the probability of observing any *specific* value is infinitely small. You can only talk meaningfully about the probability of observing a value that falls within some *range*. 
The probability of observing a value that falls within the complete range of possible values is 1. Likewise, the probability density function is normalized so that the area under it (that is, the integral of the function across its domain) equals 1. If the range of likely values is small, the curve will have to go above 1 to make this possible. Common curiosities ------------------ .. _faq_import_convention: Why is seaborn imported as `sns`? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This is an obscure reference to the `namesake `_ of the library, but you can also think of it as "seaborn name space". .. _faq_seaborn_sucks: Why is ggplot so much better than seaborn? ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Good question. Probably because you get to use the word "geom" a lot, and it's fun to say. "Geom". "Geeeeeooom". ================================================ FILE: doc/index.rst ================================================ :html_theme.sidebar_secondary.remove: seaborn: statistical data visualization ======================================= .. grid:: 6 :gutter: 1 .. grid-item:: .. image:: example_thumbs/scatterplot_matrix_thumb.png :target: ./examples/scatterplot_matrix.html .. grid-item:: .. image:: example_thumbs/errorband_lineplots_thumb.png :target: examples/errorband_lineplots.html .. grid-item:: .. image:: example_thumbs/scatterplot_sizes_thumb.png :target: examples/scatterplot_sizes.html .. grid-item:: .. image:: example_thumbs/timeseries_facets_thumb.png :target: examples/timeseries_facets.html .. grid-item:: .. image:: example_thumbs/horizontal_boxplot_thumb.png :target: examples/horizontal_boxplot.html .. grid-item:: .. image:: example_thumbs/regression_marginals_thumb.png :target: examples/regression_marginals.html .. grid:: 1 1 3 3 .. grid-item:: :columns: 12 12 6 6 Seaborn is a Python data visualization library based on `matplotlib `_. It provides a high-level interface for drawing attractive and informative statistical graphics. 
For a brief introduction to the ideas behind the library, you can read the :doc:`introductory notes ` or the `paper `_. Visit the :doc:`installation page ` to see how you can download the package and get started with it. You can browse the :doc:`example gallery ` to see some of the things that you can do with seaborn, and then check out the :doc:`tutorials ` or :doc:`API reference ` to find out how. To see the code or report a bug, please visit the `GitHub repository `_. General support questions are most at home on `stackoverflow `_, which has a dedicated channel for seaborn. .. grid-item-card:: Contents :columns: 12 12 2 2 :class-title: sd-fs-5 :class-body: sd-pl-4 .. toctree:: :maxdepth: 1 Installing Gallery Tutorial API Releases Citing FAQ .. grid-item-card:: Features :columns: 12 12 4 4 :class-title: sd-fs-5 :class-body: sd-pl-3 * :bdg-secondary:`New` Objects: :ref:`API ` | :doc:`Tutorial ` * Relational plots: :ref:`API ` | :doc:`Tutorial ` * Distribution plots: :ref:`API ` | :doc:`Tutorial ` * Categorical plots: :ref:`API ` | :doc:`Tutorial ` * Regression plots: :ref:`API ` | :doc:`Tutorial ` * Multi-plot grids: :ref:`API ` | :doc:`Tutorial ` * Figure theming: :ref:`API ` | :doc:`Tutorial ` * Color palettes: :ref:`API ` | :doc:`Tutorial ` ================================================ FILE: doc/installing.rst ================================================ .. _installing: .. currentmodule:: seaborn Installing and getting started ------------------------------ Official releases of seaborn can be installed from `PyPI `_:: pip install seaborn The basic invocation of `pip` will install seaborn and, if necessary, its mandatory dependencies. 
It is possible to include optional dependencies that give access to a few advanced features:: pip install seaborn[stats] The library is also included as part of the `Anaconda `_ distribution, and it can be installed with `conda`:: conda install seaborn As the main Anaconda repository can be slow to add new releases, you may prefer using the `conda-forge `_ channel:: conda install seaborn -c conda-forge Dependencies ~~~~~~~~~~~~ Supported Python versions ^^^^^^^^^^^^^^^^^^^^^^^^^ - Python 3.10+ Mandatory dependencies ^^^^^^^^^^^^^^^^^^^^^^ - `numpy `__ - `pandas `__ - `matplotlib `__ Optional dependencies ^^^^^^^^^^^^^^^^^^^^^ - `statsmodels `__, for advanced regression plots - `scipy `__, for clustering matrices and some advanced options - `fastcluster `__, for faster clustering of large matrices Quickstart ~~~~~~~~~~ Once you have seaborn installed, you're ready to get started. To test it out, you could load and plot one of the example datasets:: import seaborn as sns df = sns.load_dataset("penguins") sns.pairplot(df, hue="species") If you're working in a Jupyter notebook or an IPython terminal with `matplotlib mode `_ enabled, you should immediately see :ref:`the plot `. Otherwise, you may need to explicitly call :func:`matplotlib.pyplot.show`:: import matplotlib.pyplot as plt plt.show() While you can get pretty far with only seaborn imported, having access to matplotlib functions is often useful. The tutorials and API documentation typically assume the following imports:: import numpy as np import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt import seaborn as sns import seaborn.objects as so Debugging install issues ~~~~~~~~~~~~~~~~~~~~~~~~ The seaborn codebase is pure Python, and the library should generally install without issue. Occasionally, difficulties will arise because the dependencies include compiled code and link to system libraries. 
These difficulties typically manifest as errors on import with messages such as ``"DLL load failed"``. To debug such problems, read through the exception trace to figure out which specific library failed to import, and then consult the installation docs for that package to see if they have tips for your particular system. In some cases, an installation of seaborn will appear to succeed, but trying to import it will raise an error with the message ``"No module named seaborn"``. This usually means that you have multiple Python installations on your system and that your ``pip`` or ``conda`` points towards a different installation than where your interpreter lives. Resolving this issue will involve sorting out the paths on your system, but it can sometimes be avoided by invoking ``pip`` with ``python -m pip install seaborn``. Getting help ~~~~~~~~~~~~ If you think you've encountered a bug in seaborn, please report it on the `GitHub issue tracker `_. To be useful, bug reports must include the following information: - A reproducible code example that demonstrates the problem - The output that you are seeing (an image of a plot, or the error message) - A clear explanation of why you think something is wrong - The specific versions of seaborn and matplotlib that you are working with Bug reports are easiest to address if they can be demonstrated using one of the example datasets from the seaborn docs (i.e. with :func:`load_dataset`). Otherwise, it is preferable that your example generate synthetic data to reproduce the problem. If you can only demonstrate the issue with your actual dataset, you will need to share it, ideally as a CSV file. If you've encountered an error, searching the specific text of the message before opening a new issue can often help you solve the problem quickly and avoid making a duplicate report. Because matplotlib handles the actual rendering, errors or incorrect outputs may be due to a problem in matplotlib rather than one in seaborn. 
It can save time if you try to reproduce the issue in an example that uses only matplotlib, so that you can report it in the right place. But it is alright to skip this step if it's not obvious how to do it. General support questions are more at home on `stackoverflow `_, where there is a larger audience of people who will see your post and may be able to offer assistance. Your chance of getting a quick answer will be higher if you include `runnable code `_, a precise statement of what you are hoping to achieve, and a clear explanation of the problems that you have encountered. ================================================ FILE: doc/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.https://www.sphinx-doc.org/ exit /b 1 ) if "%1" == "" goto help %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% :end popd ================================================ FILE: doc/matplotlibrc ================================================ savefig.bbox : tight ================================================ FILE: doc/sphinxext/gallery_generator.py ================================================ """ Sphinx plugin to run example scripts and create a gallery page. Lightly modified from the mpld3 project. 
# Python 3 has no execfile
def execfile(filename, globals=None, locals=None):
    """Compile the script at ``filename`` and execute it.

    Parameters
    ----------
    filename : str
        Path of the Python source file to run. It is read in binary mode
        and also used as the file name recorded in the compiled code.
    globals, locals : dict, optional
        Namespaces forwarded unchanged to :func:`exec`.
    """
    with open(filename, "rb") as stream:
        source = stream.read()
    code = compile(source, filename, 'exec')
    exec(code, globals, locals)
class ExampleGenerator:
    """Tools for generating an example page from a file.

    On construction the example's leading docstring is parsed, its source
    text is read, and the script is executed (saving the figure and its
    thumbnail) unless the matching ``.rst`` file in ``target_dir`` already
    exists and is not older than the example script.

    Parameters
    ----------
    filename : str
        Path to the example script.
    target_dir : str
        Directory in which the ``.rst`` file is looked up and below which
        the rendered PNG is written.
    """

    def __init__(self, filename, target_dir):
        self.filename = filename
        self.target_dir = target_dir
        # Default thumbnail centre; extract_docstring() may override it.
        self.thumbloc = .5, .5
        self.extract_docstring()
        with open(filename) as fid:
            self.filetext = fid.read()

        outfilename = op.join(target_dir, self.rstfilename)

        # Only actually run it if the output RST file doesn't
        # exist or it was modified less recently than the example
        file_mtime = op.getmtime(filename)
        if not op.exists(outfilename) or op.getmtime(outfilename) < file_mtime:
            self.exec_file()
        else:
            print(f"skipping {self.filename}")

    @property
    def dirname(self):
        """Directory part of ``filename``."""
        return op.split(self.filename)[0]

    @property
    def fname(self):
        """Base name of ``filename`` (directory stripped)."""
        return op.split(self.filename)[1]

    @property
    def modulename(self):
        """Base name of ``filename`` without its extension."""
        return op.splitext(self.fname)[0]

    @property
    def pyfilename(self):
        """``<modulename>.py``."""
        return self.modulename + '.py'

    @property
    def rstfilename(self):
        """``<modulename>.rst``."""
        return self.modulename + ".rst"

    @property
    def htmlfilename(self):
        """``<modulename>.html``."""
        return self.modulename + '.html'

    @property
    def pngfilename(self):
        """Path of the rendered figure, relative to ``target_dir``."""
        pngfile = self.modulename + '.png'
        return "_images/" + pngfile

    @property
    def thumbfilename(self):
        """File name of the thumbnail image (``<modulename>_thumb.png``)."""
        pngfile = self.modulename + '_thumb.png'
        return pngfile

    @property
    def sphinxtag(self):
        """Label used for the page; identical to ``modulename``."""
        return self.modulename

    @property
    def pagetitle(self):
        """First line of the stripped docstring."""
        return self.docstring.strip().split('\n')[0].strip()

    @property
    def plotfunc(self):
        """Name of the first ``sns.*plot``, ``sns.*map`` or ``sns.*Grid`` call.

        The suffixes are tried in that order against the whole source text;
        an empty string is returned when none of them matches.
        """
        for suffix in ("plot", "map", "Grid"):
            match = re.search(rf"sns\.(.+{suffix})\(", self.filetext)
            if match:
                return match.group(1)
        return ""

    @property
    def components(self):
        """Comma-separated Sphinx references for the ``sns.<name>(`` calls.

        Capitalised names become ``:class:`` references, all others
        ``:func:`` references.  Each name is listed once, in order of first
        appearance, so a script that calls the same function several times
        does not produce repeated entries.
        """
        names = dict.fromkeys(re.findall(r"sns\.(\w+)\(", self.filetext))
        return ", ".join(
            f":class:`{name}`" if name[0].isupper() else f":func:`{name}`"
            for name in names
        )

    def extract_docstring(self):
        """Extract a module-level docstring.

        Sets ``docstring`` (with any ``_thumb`` lines removed),
        ``short_desc`` (the first paragraph), ``end_line`` (the line number
        used as the start of the ``:lines:`` range in ``RST_TEMPLATE``) and,
        when a ``_thumb: .x, .y`` line is present, ``thumbloc``.
        """
        # Read through a context manager so the handle is closed promptly.
        with open(self.filename) as fid:
            lines = fid.readlines()

        start_row = 0
        # Guard against an empty file before peeking at the first line.
        if lines and lines[0].startswith('#!'):
            lines.pop(0)
            start_row = 1

        docstring = ''
        first_par = ''
        end_row = 0
        line_iter = iter(lines)
        # Returning '' at exhaustion is the readline() end-of-file protocol.
        tokens = tokenize.generate_tokens(lambda: next(line_iter, ''))
        skipped = ('NEWLINE', 'COMMENT', 'NL', 'INDENT', 'DEDENT')
        for tok_type, tok_content, _, (end_row, _), _ in tokens:
            tok_name = token.tok_name[tok_type]
            if tok_name in skipped:
                continue
            if tok_name == 'STRING':
                # NOTE(review): eval() turns the string token into its value.
                # The example is a project script that exec_file() executes
                # anyway, but this must not be pointed at untrusted files.
                docstring = eval(tok_content)
                # If the docstring is formatted with several paragraphs,
                # extract the first one:
                paragraphs = '\n'.join(
                    line.rstrip() for line in docstring.split('\n')
                ).split('\n\n')
                first_par = paragraphs[0]
            # Only the first significant token is inspected.
            break

        doc_lines = docstring.split("\n")
        for line in doc_lines:
            m = re.match(r"^_thumb: (\.\d+),\s*(\.\d+)", line)
            if m:
                self.thumbloc = float(m.group(1)), float(m.group(2))
                break

        self.docstring = "\n".join(
            line for line in doc_lines if not line.startswith("_thumb")
        )
        self.short_desc = first_par
        # start_row compensates for a shebang line dropped before tokenizing.
        self.end_line = end_row + 1 + start_row

    def exec_file(self):
        """Run the example and save its current figure plus a thumbnail.

        The PNG goes to ``target_dir``/``pngfilename``; the thumbnail path
        is built from the relative directory ``example_thumbs``.
        """
        print(f"running {self.filename}")

        plt.close('all')
        my_globals = {'pl': plt,
                      'plt': plt}
        execfile(self.filename, my_globals)

        fig = plt.gcf()
        fig.canvas.draw()
        pngfile = op.join(self.target_dir, self.pngfilename)
        thumbfile = op.join("example_thumbs", self.thumbfilename)
        self.html = f""
        fig.savefig(pngfile, dpi=75, bbox_inches="tight")

        cx, cy = self.thumbloc
        create_thumbnail(pngfile, thumbfile, cx=cx, cy=cy)

    def toctree_entry(self):
        """Return the toctree line for this example's page."""
        return f" ./{op.splitext(self.htmlfilename)[0]}\n\n"

    def contents_entry(self):
        """Return the raw-HTML gallery entry for this example."""
        # NOTE(review): the visible template has no replacement fields, so
        # the format() arguments are currently unused - confirm the markup.
        return (".. raw:: html\n\n"
                " \n\n"
                "\n\n"
                "".format(self.htmlfilename,
                          self.thumbfilename,
                          self.plotfunc))
toctree::\n" " :hidden:\n\n") contents = "\n\n" # Write individual example files for filename in sorted(glob.glob(op.join(source_dir, "*.py"))): ex = ExampleGenerator(filename, target_dir) banner_data.append({"title": ex.pagetitle, "url": op.join('examples', ex.htmlfilename), "thumb": op.join(ex.thumbfilename)}) shutil.copyfile(filename, op.join(target_dir, ex.pyfilename)) output = RST_TEMPLATE.format(sphinx_tag=ex.sphinxtag, docstring=ex.docstring, end_line=ex.end_line, components=ex.components, fname=ex.pyfilename, img_file=ex.pngfilename) with open(op.join(target_dir, ex.rstfilename), 'w') as f: f.write(output) toctree += ex.toctree_entry() contents += ex.contents_entry() if len(banner_data) < 10: banner_data = (4 * banner_data)[:10] # write index file index_file = op.join(target_dir, 'index.rst') with open(index_file, 'w') as index: index.write(INDEX_TEMPLATE.format(sphinx_tag="example_gallery", toctree=toctree, contents=contents)) def setup(app): app.connect('builder-inited', main) ================================================ FILE: doc/sphinxext/tutorial_builder.py ================================================ from pathlib import Path import warnings from jinja2 import Environment import yaml import numpy as np import matplotlib as mpl import seaborn as sns import seaborn.objects as so TEMPLATE = """ :notoc: .. _tutorial: User guide and tutorial ======================= {% for section in sections %} {{ section.header }} {% for page in section.pages %} .. grid:: 1 :gutter: 2 .. grid-item-card:: .. grid:: 2 .. grid-item:: :columns: 3 .. image:: ./tutorial/{{ page }}.svg :target: ./tutorial/{{ page }}.html .. grid-item:: :columns: 9 :margin: auto .. 
def main(app):
    """Write ``tutorial.rst`` and refresh the tutorial thumbnails.

    Reads ``tutorial.yaml`` from the Sphinx source directory, renders
    ``TEMPLATE`` with its sections, and calls ``write_thumbnail`` for every
    page whose SVG is missing or older than this module's file.
    """

    content_yaml = Path(app.builder.srcdir) / "tutorial.yaml"
    tutorial_rst = Path(app.builder.srcdir) / "tutorial.rst"

    tutorial_dir = Path(app.builder.srcdir) / "tutorial"
    tutorial_dir.mkdir(exist_ok=True)

    with open(content_yaml) as fid:
        sections = yaml.load(fid, yaml.BaseLoader)

    for section in sections:
        title = section["title"]
        # reST section header: the title underlined with dashes
        section["header"] = "\n".join([title, "-" * len(title)]) if title else ""

    env = Environment().from_string(TEMPLATE)
    content = env.render(sections=sections)

    with open(tutorial_rst, "w") as fid:
        fid.write(content)

    for section in sections:
        for page in section["pages"]:
            if (
                not (svg_path := tutorial_dir / f"{page}.svg").exists()
                or svg_path.stat().st_mtime < Path(__file__).stat().st_mtime
            ):
                write_thumbnail(svg_path, page)


def write_thumbnail(svg_path, page):
    """Draw the thumbnail for ``page`` and save it as SVG at ``svg_path``.

    ``page`` must be the name of a function in this module that returns a
    figure; it is looked up through ``globals()``.
    """

    with (
        sns.axes_style("dark"),
        sns.plotting_context("notebook"),
        sns.color_palette("deep")
    ):
        fig = globals()[page]()
        # Strip all text so only the graphic remains
        for ax in fig.axes:
            ax.set(xticklabels=[], yticklabels=[], xlabel="", ylabel="", title="")
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            fig.tight_layout()
        fig.savefig(svg_path, format="svg")


def introduction():
    """Return a 2x2 figure: scatter, KDE, line, and box plots."""

    tips = sns.load_dataset("tips")
    fmri = sns.load_dataset("fmri").query("region == 'parietal'")
    penguins = sns.load_dataset("penguins")

    f = mpl.figure.Figure(figsize=(5, 5))
    with sns.axes_style("whitegrid"):
        f.subplots(2, 2)

    sns.scatterplot(
        tips, x="total_bill", y="tip", hue="sex", size="size",
        alpha=.75, palette=["C0", ".5"], legend=False, ax=f.axes[0],
    )
    sns.kdeplot(
        tips.query("size != 5"), x="total_bill", hue="size",
        palette="blend:C0,.5", fill=True, linewidth=.5,
        legend=False, common_norm=False, ax=f.axes[1],
    )
    sns.lineplot(
        fmri, x="timepoint", y="signal", hue="event",
        errorbar=("se", 2), legend=False, palette=["C0", ".5"], ax=f.axes[2],
    )
    sns.boxplot(
        penguins, x="bill_depth_mm", y="species", hue="sex",
        whiskerprops=dict(linewidth=1.5), medianprops=dict(linewidth=1.5),
        boxprops=dict(linewidth=1.5), capprops=dict(linewidth=0),
        width=.5, palette=["C0", ".8"], whis=5, ax=f.axes[3],
    )
    f.axes[3].legend_ = None
    for ax in f.axes:
        ax.set(xticks=[], yticks=[])
    return f


def function_overview():
    """Return a diagram of three module boxes with their function names."""

    from matplotlib.patches import FancyBboxPatch

    f = mpl.figure.Figure(figsize=(7, 5))
    with sns.axes_style("white"):
        ax = f.subplots()
    f.subplots_adjust(0, 0, 1, 1)
    ax.set_axis_off()
    ax.set(xlim=(0, 1), ylim=(0, 1))

    deep = sns.color_palette("deep")
    colors = dict(relational=deep[0], distributions=deep[1], categorical=deep[2])
    dark = sns.color_palette("dark")
    text_colors = dict(relational=dark[0], distributions=dark[1], categorical=dark[2])

    functions = dict(
        relational=["scatterplot", "lineplot"],
        distributions=["histplot", "kdeplot", "ecdfplot", "rugplot"],
        categorical=[
            "stripplot", "swarmplot", "boxplot", "violinplot", "pointplot", "barplot"
        ],
    )
    pad, w, h = .06, .2, .15
    xs, y = np.arange(0, 1, 1 / 3) + pad * 1.05, .7
    for x, mod in zip(xs, functions):
        # Append an alpha channel to the module color
        color = colors[mod] + (.2,)
        text_color = text_colors[mod]
        # A white patch goes under each translucent one
        ax.add_artist(FancyBboxPatch((x, y), w, h, f"round,pad={pad}", color="white"))
        ax.add_artist(FancyBboxPatch(
            (x, y), w, h, f"round,pad={pad}",
            linewidth=1, edgecolor=text_color, facecolor=color,
        ))
        ax.text(
            x + w / 2, y + h / 2, f"{mod[:3]}plot\n({mod})",
            ha="center", va="center", size=20, color=text_color
        )
        for i, func in enumerate(functions[mod]):
            x_i, y_i = x + w / 2, y - i * .1 - h / 2 - pad
            xy = x_i - w / 2, y_i - pad / 3
            ax.add_artist(
                FancyBboxPatch(xy, w, h / 4, f"round,pad={pad / 3}", color="white")
            )
            ax.add_artist(FancyBboxPatch(
                xy, w, h / 4, f"round,pad={pad / 3}",
                linewidth=1, edgecolor=text_color, facecolor=color
            ))
            ax.text(x_i, y_i, func, ha="center", va="center", size=16, color=text_color)
        # Connector from the module box down to its last function box
        ax.plot([x_i, x_i], [y, y_i], zorder=-100, color=text_color, lw=1)
    return f


def data_structure():
    """Return a figure of colored panels laid out on two grid specs."""

    f = mpl.figure.Figure(figsize=(7, 5))
    gs = mpl.gridspec.GridSpec(
        figure=f, ncols=6, nrows=2, height_ratios=(1, 20),
        left=0, right=.35, bottom=0, top=.9, wspace=.1, hspace=.01
    )
    colors = [c + (.5,) for c in sns.color_palette("deep")]
    f.add_subplot(gs[0, :], facecolor=".8")
    for i in range(gs.ncols):
        f.add_subplot(gs[1:, i], facecolor=colors[i])

    gs = mpl.gridspec.GridSpec(
        figure=f, ncols=2, nrows=2, height_ratios=(1, 8), width_ratios=(1, 11),
        left=.4, right=1, bottom=.2, top=.8, wspace=.015, hspace=.02
    )
    f.add_subplot(gs[0, 1:], facecolor=colors[2])
    f.add_subplot(gs[1:, 0], facecolor=colors[1])
    f.add_subplot(gs[1, 1], facecolor=colors[0])
    return f


def error_bars():
    """Return a point plot of the diamonds data with error bars."""

    diamonds = sns.load_dataset("diamonds")
    with sns.axes_style("whitegrid"):
        g = sns.catplot(
            diamonds, x="carat", y="clarity", hue="clarity", kind="point",
            errorbar=("sd", .5), join=False, legend=False,
            facet_kws={"despine": False},
            palette="ch:s=-.2,r=-.2,d=.4,l=.6_r", scale=.75, capsize=.3,
        )
    g.ax.yaxis.set_inverted(False)
    return g.figure


def properties():
    """Return stacked rows of dots, each row varying one Dot property."""

    f = mpl.figure.Figure(figsize=(5, 5))

    x = np.arange(1, 11)
    y = np.zeros_like(x)

    p = so.Plot(x, y)
    ps = 14
    plots = [
        p.add(so.Dot(pointsize=ps), color=map(str, x)),
        p.add(so.Dot(color=".3", pointsize=ps), alpha=x),
        p.add(so.Dot(color=".9", pointsize=ps, edgewidth=2), edgecolor=x),
        p.add(so.Dot(color=".3"), pointsize=x).scale(pointsize=(4, 18)),
        p.add(so.Dot(pointsize=ps, color=".9", edgecolor=".2"), edgewidth=x),
        p.add(so.Dot(pointsize=ps, color=".3"), marker=map(str, x)),
        p.add(so.Dot(pointsize=ps, color=".3", marker="x"), stroke=x),
    ]

    with sns.axes_style("ticks"):
        axs = f.subplots(len(plots))
    for p, ax in zip(plots, axs):
        p.on(ax).plot()
        ax.set(xticks=x, yticks=[], xticklabels=[], ylim=(-.2, .3))
        sns.despine(ax=ax, left=True)
    # Drop any legends attached to the figure
    f.legends = []
    return f


def objects_interface():
    """Return an annotated two-line code snippet drawn as a figure."""

    f = mpl.figure.Figure(figsize=(5, 4))
    C = sns.color_palette("deep")
    ax = f.subplots()
    fontsize = 22
    # ((x, y) of the lower-left corner, width) of each highlight rectangle
    rects = [((.135, .50), .69), ((.275, .38), .26), ((.59, .38), .40)]
    for i, (xy, w) in enumerate(rects):
        ax.add_artist(mpl.patches.Rectangle(xy, w, .09, color=C[i], alpha=.2, lw=0))
    ax.text(0, .52, "Plot(data, 'x', 'y', color='var1')", size=fontsize, color=".2")
    ax.text(0, .40, ".add(Dot(alpha=.5), marker='var2')", size=fontsize, color=".2")
    # (label, arrow target, text offset in points)
    annots = [
        ("Mapped\nin all layers", (.48, .62), (0, 55)),
        ("Set directly", (.41, .35), (0, -55)),
        ("Mapped\nin this layer", (.80, .35), (0, -55)),
    ]
    for i, (text, xy, xytext) in enumerate(annots):
        ax.annotate(
            text, xy, xytext,
            textcoords="offset points", fontsize=18, ha="center", va="center",
            arrowprops=dict(arrowstyle="->", linewidth=1.5, color=C[i]), color=C[i],
        )
    ax.set_axis_off()
    f.subplots_adjust(0, 0, 1, 1)

    return f


def relational():
    """Return a relplot of the mpg data with size and hue mappings."""

    mpg = sns.load_dataset("mpg")
    with sns.axes_style("ticks"):
        g = sns.relplot(
            data=mpg, x="horsepower", y="mpg", size="displacement", hue="weight",
            sizes=(50, 500), hue_norm=(2000, 4500), alpha=.75, legend=False,
            palette="ch:start=-.5,rot=.7,dark=.3,light=.7_r",
        )
    g.figure.set_size_inches(5, 5)
    return g.figure


def distributions():
    """Return a displot of penguin flipper lengths, one row per island."""

    penguins = sns.load_dataset("penguins").dropna()
    with sns.axes_style("white"):
        g = sns.displot(
            penguins, x="flipper_length_mm", row="island",
            binwidth=4, kde=True, line_kws=dict(linewidth=2), legend=False,
        )
    sns.despine(left=True)
    g.figure.set_size_inches(5, 5)
    return g.figure


def categorical():
    """Return a box-kind catplot of penguin body mass, one column per sex."""

    penguins = sns.load_dataset("penguins").dropna()
    with sns.axes_style("whitegrid"):
        g = sns.catplot(
            penguins, x="sex", y="body_mass_g", hue="island", col="sex",
            kind="box", whis=np.inf, legend=False, sharex=False,
        )
    sns.despine(left=True)
    g.figure.set_size_inches(5, 5)
    return g.figure


def regression():
    """Return an lmplot of the anscombe data, wrapped into two columns."""

    anscombe = sns.load_dataset("anscombe")
    with sns.axes_style("white"):
        g = sns.lmplot(
            anscombe, x="x", y="y", hue="dataset", col="dataset", col_wrap=2,
            scatter_kws=dict(edgecolor=".2", facecolor=".7", s=80),
            line_kws=dict(lw=4), ci=None,
        )
    g.set(xlim=(2, None), ylim=(2, None))
    g.figure.set_size_inches(5, 5)
    return g.figure


def axis_grids():
    """Return a pairplot of a 200-row sample of the penguins data."""

    penguins = sns.load_dataset("penguins").sample(200, random_state=0)
    with sns.axes_style("ticks"):
        g = sns.pairplot(
            penguins.drop("flipper_length_mm", axis=1),
            diag_kind="kde", diag_kws=dict(fill=False),
            plot_kws=dict(s=40, fc="none", ec="C0", alpha=.75, linewidth=.75),
        )
    g.figure.set_size_inches(5, 5)
    return g.figure


def aesthetics():
    """Return a 2x2 figure with one empty axes per seaborn axes style."""

    f = mpl.figure.Figure(figsize=(5, 5))
    for i, style in enumerate(["darkgrid", "white", "ticks", "whitegrid"], 1):
        with sns.axes_style(style):
            ax = f.add_subplot(2, 2, i)
        ax.set(xticks=[0, .25, .5, .75, 1], yticks=[0, .25, .5, .75, 1])
    # axes[1] and axes[2] are the "white" and "ticks" panels
    sns.despine(ax=f.axes[1])
    sns.despine(ax=f.axes[2])
    return f


def color_palettes():
    """Return stacked 10-color swatches, one row per named palette."""

    f = mpl.figure.Figure(figsize=(5, 5))
    palettes = ["deep", "husl", "gray", "ch:", "mako", "vlag", "icefire"]
    axs = f.subplots(len(palettes))
    x = np.arange(10)
    for ax, name in zip(axs, palettes):
        cmap = mpl.colors.ListedColormap(sns.color_palette(name, x.size))
        ax.pcolormesh(x[None, :], linewidth=.5, edgecolor="w", alpha=.8, cmap=cmap)
        ax.set_axis_off()
    return f


def setup(app):
    """Sphinx extension hook: run ``main`` once the builder exists."""
    app.connect("builder-inited", main)


# --- doc/tools/extract_examples.py helpers ---

def line_type(line):
    """Return "code" for a line starting with four spaces, else "markdown"."""

    if line.startswith("    "):
        return "code"
    else:
        return "markdown"


def add_cell(nb, lines, cell_type):
    """Join ``lines`` and append them to ``nb`` as a ``cell_type`` cell."""

    cell_objs = {
        "code": nbformat.v4.new_code_cell,
        "markdown": nbformat.v4.new_markdown_cell,
    }
    text = "\n".join(lines)
    cell = cell_objs[cell_type](text)
    nb["cells"].append(cell)
# Module-level state used by the functions below
XY_CACHE = {}

STATIC_DIR = "_static"


def poisson_disc_sample(array_radius, pad_radius, candidates=100, d=2, seed=None):
    """Find positions using poisson-disc sampling.

    Parameters
    ----------
    array_radius : points are kept strictly inside this distance of the origin.
    pad_radius : minimum distance between any two accepted points.
    candidates : number of tries per active sample before it is retired.
    d : dimensionality of the points.
    seed : seed passed to ``numpy.random.default_rng``.

    Returns
    -------
    Array of shape ``(n_points, d)``; the first point is the origin.
    Results are cached in ``XY_CACHE``, so a repeated call with the same
    arguments returns the same array object.
    """
    # See http://bost.ocks.org/mike/algorithms/

    # Cache the results. Every argument that shapes the output is part of
    # the key, so e.g. a 3-d request never gets a cached 2-d sample.
    key = array_radius, pad_radius, candidates, d, seed
    if key in XY_CACHE:
        return XY_CACHE[key]

    rng = np.random.default_rng(seed)
    uniform = rng.uniform
    randint = rng.integers

    # Start at a fixed point we know will work
    start = np.zeros(d)
    samples = [start]
    queue = [start]

    while queue:

        # Pick a sample to expand from
        s_idx = randint(len(queue))
        s = queue[s_idx]

        for i in range(candidates):
            # Generate a candidate from this sample
            coords = uniform(s - 2 * pad_radius, s + 2 * pad_radius, d)

            # Check the two conditions to accept the candidate
            in_array = np.sqrt(np.sum(coords ** 2)) < array_radius
            in_ring = np.all(distance.cdist(samples, [coords]) > pad_radius)

            if in_array and in_ring:
                # Accept the candidate
                samples.append(coords)
                queue.append(coords)
                break

        # Kept as a post-loop check (not for/else) so the sequence of
        # random draws, and hence the output for a seed, is unchanged.
        if (i + 1) == candidates:
            # We've exhausted the particular sample
            queue.pop(s_idx)

    samples = np.array(samples)
    XY_CACHE[key] = samples
    return samples


def logo(
    ax,
    color_kws, ring, ring_idx, edge,
    pdf_means, pdf_sigma, dy, y0, w, h,
    hist_mean, hist_sigma, hist_y0, lw, skip,
    scatter, pad, scale,
):
    """Draw the logo mark onto ``ax``.

    The mark is a white disc clipped to a circle and filled with stacked
    gaussian "waves", optional histogram bars (``hist_mean`` truthy), an
    optional outer ring (``ring``), and optional scattered dots in the top
    wave (``scatter``).
    """
    # Imported here because the ``scipy.signal.gaussian`` alias was removed
    # from SciPy; ``scipy.signal.windows`` is its supported home.
    from scipy.signal.windows import gaussian

    # Square, invisible axes with specified limits to center the logo
    ax.set(xlim=(35 + w, 95 - w), ylim=(-3, 53))
    ax.set_axis_off()
    ax.set_aspect('equal')

    # Magic numbers for the logo circle
    radius = 27
    center = 65, 25

    # Full x and y grids for a gaussian curve
    x = np.arange(101)
    y = gaussian(x.size, pdf_sigma)

    x0 = 30  # Magic number
    xx = x[x0:]

    # Vertical distances between the PDF curves
    n = len(pdf_means)
    dys = np.linspace(0, (n - 1) * dy, n) - (n * dy / 2)
    dys -= dys.mean()

    # Compute the PDF curves with vertical offsets
    pdfs = [h * (y[x0 - m:-m] + y0 + dy) for m, dy in zip(pdf_means, dys)]

    # Add in constants to fill from bottom and to top
    pdfs.insert(0, np.full(xx.shape, -h))
    pdfs.append(np.full(xx.shape, 50 + h))

    # Color gradient
    colors = sns.cubehelix_palette(n + 1 + bool(hist_mean), **color_kws)

    # White fill between curves and around edges
    bg = patches.Circle(
        center, radius=radius - 1 + ring, color="white",
        transform=ax.transData, zorder=0,
    )
    ax.add_artist(bg)

    # Clipping artist (not shown) for the interior elements
    fg = patches.Circle(center, radius=radius - edge, transform=ax.transData)

    # Ring artist to surround the circle (optional)
    if ring:
        wedge = patches.Wedge(
            center, r=radius + edge / 2, theta1=0, theta2=360, width=edge / 2,
            transform=ax.transData, color=colors[ring_idx], alpha=1
        )
        ax.add_artist(wedge)

    # Add histogram bars
    if hist_mean:
        hist_color = colors.pop(0)
        hist_y = gaussian(x.size, hist_sigma)
        hist = 1.1 * h * (hist_y[x0 - hist_mean:-hist_mean] + hist_y0)
        dx = x[skip] - x[0]
        hist_x = xx[::skip]
        hist_h = h + hist[::skip]
        # Magic number to avoid tiny sliver of bar on edge
        use = hist_x < center[0] + radius * .5
        bars = ax.bar(
            hist_x[use], hist_h[use], bottom=-h, width=dx,
            align="edge", color=hist_color, ec="w", lw=lw,
            zorder=3,
        )
        for bar in bars:
            bar.set_clip_path(fg)

    # Add each smooth PDF "wave"
    for i, pdf in enumerate(pdfs[1:], 1):
        u = ax.fill_between(xx, pdfs[i - 1] + w, pdf, color=colors[i - 1], lw=0)
        u.set_clip_path(fg)

    # Add scatterplot in top wave area
    if scatter:
        seed = sum(map(ord, "seaborn logo"))
        xy = poisson_disc_sample(radius - edge - ring, pad, seed=seed)
        clearance = distance.cdist(xy + center, np.c_[xx, pdfs[-2]])
        use = clearance.min(axis=1) > pad / 1.8
        x, y = xy[use].T
        sizes = (x - y) % 9

        points = ax.scatter(
            x + center[0], y + center[1], s=scale * (10 + sizes * 5),
            zorder=5, color=colors[-1], ec="w", lw=scale / 2,
        )
        # Clip the dots to the last wave drawn, then hide that wave
        path = u.get_paths()[0]
        points.set_clip_path(path, transform=u.get_transform())
        u.set_visible(False)


def savefig(fig, shape, variant):
    """Save ``fig`` as PNG and SVG under ``STATIC_DIR``.

    Files are named ``logo-{shape}-{variant}bg.{ext}``. The "white" variant
    gets an opaque white background; every other variant is transparent.
    """
    fig.subplots_adjust(0, 0, 1, 1, 0, 0)

    # Decide from the argument, not from a global set by the calling script
    facecolor = (1, 1, 1, 1) if variant == "white" else (1, 1, 1, 0)

    for ext in ["png", "svg"]:
        fig.savefig(f"{STATIC_DIR}/logo-{shape}-{variant}bg.{ext}", facecolor=facecolor)
------------------------------------------------------------------------ # fig, axs = plt.subplots(2, 1, figsize=(2, 2.5), dpi=100, gridspec_kw=dict(height_ratios=[4, 1])) logo(axs[0], **kwargs) font = { "family": "avenir", "color": color, "weight": "regular", "size": 34, } axs[1].text(.5, 1, "seaborn", ha="center", va="top", fontdict=font, transform=axs[1].transAxes) axs[1].set_axis_off() savefig(fig, "tall", bg) ================================================ FILE: doc/tools/nb_to_doc.py ================================================ #! /usr/bin/env python """Execute a .ipynb file, write out a processed .rst and clean .ipynb. Some functions in this script were copied from the nbstripout tool: Copyright (c) 2015 Min RK, Florian Rathgeber, Michael McNeil Forbes 2019 Casper da Costa-Luis Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
""" import os import sys import nbformat from nbconvert import RSTExporter from nbconvert.preprocessors import ( ExecutePreprocessor, TagRemovePreprocessor, ExtractOutputPreprocessor ) from traitlets.config import Config class MetadataError(Exception): pass def pop_recursive(d, key, default=None): """dict.pop(key) where `key` is a `.`-delimited list of nested keys. >>> d = {'a': {'b': 1, 'c': 2}} >>> pop_recursive(d, 'a.c') 2 >>> d {'a': {'b': 1}} """ nested = key.split('.') current = d for k in nested[:-1]: if hasattr(current, 'get'): current = current.get(k, {}) else: return default if not hasattr(current, 'pop'): return default return current.pop(nested[-1], default) def strip_output(nb): """ Strip the outputs, execution count/prompt number and miscellaneous metadata from a notebook object, unless specified to keep either the outputs or counts. """ keys = {'metadata': [], 'cell': {'metadata': ["execution"]}} nb.metadata.pop('signature', None) nb.metadata.pop('widgets', None) for field in keys['metadata']: pop_recursive(nb.metadata, field) if 'NB_KERNEL' in os.environ: nb.metadata['kernelspec']['name'] = os.environ['NB_KERNEL'] nb.metadata['kernelspec']['display_name'] = os.environ['NB_KERNEL'] for cell in nb.cells: if 'outputs' in cell: cell['outputs'] = [] if 'prompt_number' in cell: cell['prompt_number'] = None if 'execution_count' in cell: cell['execution_count'] = None # Always remove this metadata for output_style in ['collapsed', 'scrolled']: if output_style in cell.metadata: cell.metadata[output_style] = False if 'metadata' in cell: for field in ['collapsed', 'scrolled', 'ExecuteTime']: cell.metadata.pop(field, None) for (extra, fields) in keys['cell'].items(): if extra in cell: for field in fields: pop_recursive(getattr(cell, extra), field) return nb if __name__ == "__main__": # Get the desired ipynb file path and parse into components _, fpath, outdir = sys.argv basedir, fname = os.path.split(fpath) fstem = fname[:-6] # Read the notebook with 
open(fpath) as f: nb = nbformat.read(f, as_version=4) # Run the notebook kernel = os.environ.get("NB_KERNEL", None) if kernel is None: kernel = nb["metadata"]["kernelspec"]["name"] ep = ExecutePreprocessor( timeout=600, kernel_name=kernel, extra_arguments=["--InlineBackend.rc=figure.dpi=88"] ) ep.preprocess(nb, {"metadata": {"path": basedir}}) # Remove plain text execution result outputs for cell in nb.get("cells", {}): if "show-output" in cell["metadata"].get("tags", []): continue fields = cell.get("outputs", []) for field in fields: if field["output_type"] == "execute_result": data_keys = field["data"].keys() for key in list(data_keys): if key == "text/plain": field["data"].pop(key) if not field["data"]: fields.remove(field) # Convert to .rst formats exp = RSTExporter() c = Config() c.TagRemovePreprocessor.remove_cell_tags = {"hide"} c.TagRemovePreprocessor.remove_input_tags = {"hide-input"} c.TagRemovePreprocessor.remove_all_outputs_tags = {"hide-output"} c.ExtractOutputPreprocessor.output_filename_template = \ f"{fstem}_files/{fstem}_" + "{cell_index}_{index}{extension}" exp.register_preprocessor(TagRemovePreprocessor(config=c), True) exp.register_preprocessor(ExtractOutputPreprocessor(config=c), True) body, resources = exp.from_notebook_node(nb) # Clean the output on the notebook and save a .ipynb back to disk nb = strip_output(nb) with open(fpath, "wt") as f: nbformat.write(nb, f) # Write the .rst file rst_path = os.path.join(outdir, f"{fstem}.rst") with open(rst_path, "w") as f: f.write(body) # Write the individual image outputs imdir = os.path.join(outdir, f"{fstem}_files") if not os.path.exists(imdir): os.mkdir(imdir) for imname, imdata in resources["outputs"].items(): if imname.startswith(fstem): impath = os.path.join(outdir, f"{imname}") with open(impath, "wb") as f: f.write(imdata) ================================================ FILE: doc/tools/set_nb_kernels.py ================================================ """Recursively set the kernel name for all 
jupyter notebook files.""" import sys from glob import glob import nbformat if __name__ == "__main__": _, kernel_name = sys.argv nb_paths = glob("./**/*.ipynb", recursive=True) for path in nb_paths: with open(path) as f: nb = nbformat.read(f, as_version=4) nb["metadata"]["kernelspec"]["name"] = kernel_name nb["metadata"]["kernelspec"]["display_name"] = kernel_name with open(path, "w") as f: nbformat.write(nb, f) ================================================ FILE: doc/tutorial.yaml ================================================ - title: pages: - introduction - title: API Overview pages: - function_overview - data_structure - title: Objects interface pages: - objects_interface - properties - title: Plotting functions pages: - relational - distributions - categorical - title: Statistical operations pages: - error_bars - regression - title: Multi-plot grids pages: - axis_grids - title: Figure aesthetics pages: - aesthetics - color_palettes ================================================ FILE: doc/whatsnew/index.rst ================================================ .. _whatsnew: What's new in each version ========================== v0.13 ----- .. toctree:: :maxdepth: 2 v0.13.2 v0.13.1 v0.13.0 v0.12 ----- .. toctree:: :maxdepth: 2 v0.12.2 v0.12.1 v0.12.0 v0.11 ----- .. toctree:: :maxdepth: 2 v0.11.2 v0.11.1 v0.11.0 v0.10 ----- .. toctree:: :maxdepth: 2 v0.10.1 v0.10.0 v0.9 ---- .. toctree:: :maxdepth: 2 v0.9.1 v0.9.0 v0.8 ---- .. toctree:: :maxdepth: 2 v0.8.1 v0.8.0 v0.7 ---- .. toctree:: :maxdepth: 2 v0.7.1 v0.7.0 v0.6 ---- .. toctree:: :maxdepth: 2 v0.6.0 v0.5 ---- .. toctree:: :maxdepth: 2 v0.5.1 v0.5.0 v0.4 ---- .. toctree:: :maxdepth: 2 v0.4.0 v0.3 ---- .. toctree:: :maxdepth: 2 v0.3.1 v0.3.0 v0.2 ---- .. 
toctree:: :maxdepth: 2 v0.2.1 v0.2.0 ================================================ FILE: doc/whatsnew/v0.10.0.rst ================================================ v0.10.0 (January 2020) ---------------------- This is a major update that is being released simultaneously with version 0.9.1. It has all of the same features (and bugs!) as 0.9.1, but there are important changes to the dependencies. Most notably, all support for Python 2 has now been dropped. Support for Python 3.5 has also been dropped. Seaborn is now strictly compatible with Python 3.6+. Minimally supported versions of the dependent PyData libraries have also been increased, in some cases substantially. While seaborn has tended to be very conservative about maintaining compatibility with older dependencies, this was causing increasing pain during development. At the same time, these libraries are now much easier to install. Going forward, seaborn will likely stay close to the `Numpy community guidelines <https://numpy.org/neps/nep-0029-deprecation_policy.html>`_ for version support. This release also removes a few previously-deprecated features: - The ``tsplot`` function and ``seaborn.timeseries`` module have been removed. Recall that ``tsplot`` was replaced with :func:`lineplot`. - The ``seaborn.apionly`` entry-point has been removed. - The ``seaborn.linearmodels`` module (previously renamed to ``seaborn.regression``) has been removed. ================================================ FILE: doc/whatsnew/v0.10.1.rst ================================================ v0.10.1 (April 2020) -------------------- This is a minor release with bug fixes for issues identified since 0.10.0. - Fixed a bug that appeared within the bootstrapping algorithm on 32-bit systems. - Fixed a bug where :func:`regplot` would crash on singleton inputs. Now a crash is avoided and regression estimation/plotting is skipped. - Fixed a bug where :func:`heatmap` would ignore user-specified under/over/bad values when recentering a colormap.
- Fixed a bug where :func:`heatmap` would use values from masked cells when computing default colormap limits. - Fixed a bug where :func:`despine` would cause an error when trying to trim spines on a matplotlib categorical axis. - Adapted to a change in matplotlib that caused problems with single swarm plots. - Added the ``showfliers`` parameter to :func:`boxenplot` to suppress plotting of outlier data points, matching the API of :func:`boxplot`. - Avoided seeing an error from statsmodels when data with an IQR of 0 is passed to :func:`kdeplot`. - Added the ``legend.title_fontsize`` to the :func:`plotting_context` definition. - Deprecated several utility functions that are no longer used internally (``percentiles``, ``sig_stars``, ``pmf_hist``, and ``sort_df``). ================================================ FILE: doc/whatsnew/v0.11.0.rst ================================================ v0.11.0 (September 2020) ------------------------ This is a major release with several important new features, enhancements to existing functions, and changes to the library. Highlights include an overhaul and modernization of the distributions plotting functions, more flexible data specification, new colormaps, and better narrative documentation. For an overview of the new features and a guide to updating, see `this Medium post `_. Required keyword arguments ~~~~~~~~~~~~~~~~~~~~~~~~~~ |API| Most plotting functions now require all of their parameters to be specified using keyword arguments. To ease adaptation, code without keyword arguments will trigger a ``FutureWarning`` in v0.11. In a future release (v0.12 or v0.13, depending on release cadence), this will become an error. Once keyword arguments are fully enforced, the signature of the plotting functions will be reorganized to accept ``data`` as the first and only positional argument (:pr:`2052,2081`).
Modernization of distribution functions ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The distribution module has been completely overhauled, modernizing the API and introducing several new functions and features within existing functions. Some new features are explained here; the :doc:`tutorial documentation </tutorial/distributions>` has also been rewritten and serves as a good introduction to the functions. New plotting functions ^^^^^^^^^^^^^^^^^^^^^^ |Feature| |Enhancement| First, three new functions, :func:`displot`, :func:`histplot` and :func:`ecdfplot` have been added (:pr:`2157`, :pr:`2125`, :pr:`2141`). The figure-level :func:`displot` function is an interface to the various distribution plots (analogous to :func:`relplot` or :func:`catplot`). It can draw univariate or bivariate histograms, density curves, ECDFs, and rug plots on a :class:`FacetGrid`. The axes-level :func:`histplot` function draws univariate or bivariate histograms with a number of features, including: - mapping multiple distributions with a ``hue`` semantic - normalization to show density, probability, or frequency statistics - flexible parameterization of bin size, including proper bins for discrete variables - adding a KDE fit to show a smoothed distribution over all bin statistics - experimental support for histograms over categorical and datetime variables. The axes-level :func:`ecdfplot` function draws univariate empirical cumulative distribution functions, using a similar interface. Changes to existing functions ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |API| |Feature| |Enhancement| |Defaults| Second, the existing functions :func:`kdeplot` and :func:`rugplot` have been completely overhauled (:pr:`2060,2104`). The overhauled functions now share a common API with the rest of seaborn, they can show conditional distributions by mapping a third variable with a ``hue`` semantic, and they have been improved in numerous other ways. The github pull request (:pr:`2104`) has a longer explanation of the changes and the motivation behind them. This is a necessarily API-breaking change. The parameter names for the positional variables are now ``x`` and ``y``, and the old names have been deprecated. Efforts were made to handle and warn when using the deprecated API, but it is strongly suggested to check your plots carefully.
Additionally, the statsmodels-based computation of the KDE has been removed. Because there were some inconsistencies between the way different parameters (specifically, ``bw``, ``clip``, and ``cut``) were implemented by each backend, this may cause plots to look different with non-default parameters. Support for using non-Gaussian kernels, which was available only in the statsmodels backend, has been removed. Other new features include: - several options for representing multiple densities (using the ``multiple`` and ``common_norm`` parameters) - weighted density estimation (using the new ``weights`` parameter) - better control over the smoothing bandwidth (using the new ``bw_adjust`` parameter) - more meaningful parameterization of the contours that represent a bivariate density (using the ``thresh`` and ``levels`` parameters) - log-space density estimation (using the new ``log_scale`` parameter, or by scaling the data axis before plotting) - "bivariate" rug plots with a single function call (by assigning both ``x`` and ``y``) Deprecations ^^^^^^^^^^^^ |API| Finally, the :func:`distplot` function is now formally deprecated. Its features have been subsumed by :func:`displot` and :func:`histplot`. Some effort was made to gradually transition :func:`distplot` by adding the features in :func:`displot` and handling backwards compatibility, but this proved to be too difficult. The similarity in the names will likely cause some confusion during the transition, which is regrettable. Related enhancements and changes ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ |API| |Feature| |Enhancement| |Defaults| These additions facilitated new features (and forced changes) in :func:`jointplot` and :class:`JointGrid` (:pr:`2210`) and in :func:`pairplot` and :class:`PairGrid` (:pr:`2234`). - Added support for the ``hue`` semantic in :func:`jointplot`/:class:`JointGrid`. This support is lightweight and simply delegates the mapping to the underlying axes-level functions. 
- Delegated the handling of ``hue`` in :class:`PairGrid`/:func:`pairplot` to the plotting function when it understands ``hue``, meaning that (1) the zorder of scatterplot points will be determined by row in dataframe, (2) additional options for resolving hue (e.g. the ``multiple`` parameter) can be used, and (3) numeric hue variables can be naturally mapped when using :func:`scatterplot`. - Added ``kind="hist"`` to :func:`jointplot`, which draws a bivariate histogram on the joint axes and univariate histograms on the marginal axes, as well as both ``kind="hist"`` and ``kind="kde"`` to :func:`pairplot`, which behaves likewise. - The various modes of :func:`jointplot` that plot marginal histograms now use :func:`histplot` rather than :func:`distplot`. This slightly changes the default appearance and affects the valid keyword arguments that can be passed to customize the plot. Likewise, the marginal histogram plots in :func:`pairplot` now use :func:`histplot`. Standardization and enhancements of data ingest ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |Feature| |Enhancement| |Docs| The code that processes input data has been refactored and enhanced. In v0.11, this new code takes effect for the relational and distribution modules; other modules will be refactored to use it in future releases (:pr:`2071`). These changes should be transparent for most use-cases, although they allow a few new features: - Named variables for long-form data can refer to the named index of a :class:`pandas.DataFrame` or to levels in the case of a multi-index. Previously, it was necessary to call :meth:`pandas.DataFrame.reset_index` before using index variables (e.g., after a groupby operation). - :func:`relplot` now has the same flexibility as the axes-level functions to accept data in long- or wide-format and to accept data vectors (rather than named variables) in long-form mode. - The data parameter can now be a Python ``dict`` or an object that implements that interface. 
This is a new feature for wide-form data. For long-form data, it was previously supported but not documented. - A wide-form data object can have a mixture of types; the non-numeric types will be removed before plotting. Previously, this caused an error. - There are better error messages for other instances of data mis-specification. See the new user guide chapter on :doc:`data formats </tutorial/data_structure>` for more information about what is supported. Other changes ~~~~~~~~~~~~~ Documentation improvements ^^^^^^^^^^^^^^^^^^^^^^^^^^ - |Docs| Added two new chapters to the user guide, one giving an overview of the :doc:`types of functions in seaborn </tutorial/function_overview>`, and one discussing the different :doc:`data formats </tutorial/data_structure>` that seaborn understands. - |Docs| Expanded the :doc:`color palette tutorial </tutorial/color_palettes>` to give more background on color theory and better motivate the use of color in statistical graphics. - |Docs| Added more information to the :doc:`installation guidelines </installing>` and streamlined the :doc:`introduction
` page. - |Docs| Improved cross-linking within the seaborn docs and between the seaborn and matplotlib docs. Theming ^^^^^^^ - |API| The :func:`set` function has been renamed to :func:`set_theme` for more clarity about what it does. For the foreseeable future, :func:`set` will remain as an alias, but it is recommended to update your code. Relational plots ^^^^^^^^^^^^^^^^ - |Enhancement| |Defaults| Reduced some of the surprising behavior of relational plot legends when using a numeric hue or size mapping (:pr:`2229`): - Added an "auto" mode (the new default) that chooses between "brief" and "full" legends based on the number of unique levels of each variable. - Modified the ticking algorithm for a "brief" legend to show up to 6 values and not to show values outside the limits of the data. - Changed the approach to the legend title: the normal matplotlib legend title is used when only one variable is assigned a semantic mapping, whereas the old approach of adding an invisible legend artist with a subtitle label is used only when multiple semantic variables are defined. - Modified legend subtitles to be left-aligned and to be drawn in the default legend title font size. - |Enhancement| |Defaults| Changed how functions that use different representations for numeric and categorical data handle vectors with an ``object`` data type. Previously, data was considered numeric if it could be coerced to a float representation without error. Now, object-typed vectors are considered numeric only when their contents are themselves numeric. As a consequence, numbers that are encoded as strings will now be treated as categorical data (:pr:`2084`). - |Enhancement| |Defaults| Plots with a ``style`` semantic can now generate an infinite number of unique dashes and/or markers by default. Previously, an error would be raised if the ``style`` variable had more levels than could be mapped using the default lists. 
The existing defaults were slightly modified as part of this change; if you need to exactly reproduce plots from earlier versions, refer to the `old defaults `_ (:pr:`2075`). - |Defaults| Changed how :func:`scatterplot` sets the default linewidth for the edges of the scatter points. New behavior is to scale with the point sizes themselves (on a plot-wise, not point-wise basis). This change also slightly reduces the default width when point sizes are not varied. Set ``linewidth=0.75`` to reproduce the previous behavior. (:pr:`2708`). - |Enhancement| Improved support for datetime variables in :func:`scatterplot` and :func:`lineplot` (:pr:`2138`). - |Fix| Fixed a bug where :func:`lineplot` did not pass the ``linestyle`` parameter down to matplotlib (:pr:`2095`). - |Fix| Adapted to a change in matplotlib that prevented passing vectors of literal values to ``c`` and ``s`` in :func:`scatterplot` (:pr:`2079`). Categorical plots ^^^^^^^^^^^^^^^^^ - |Enhancement| |Defaults| |Fix| Fixed a few computational issues in :func:`boxenplot` and improved its visual appearance (:pr:`2086`): - Changed the default method for computing the number of boxes to ``k_depth="tukey"``, as the previous default (``k_depth="proportion"``) is based on a heuristic that produces too many boxes for small datasets. - Added the option to specify the specific number of boxes (e.g. ``k_depth=6``) or to plot boxes that will cover most of the data points (``k_depth="full"``). - Added a new parameter, ``trust_alpha``, to control the number of boxes when ``k_depth="trustworthy"``. - Changed the visual appearance of :func:`boxenplot` to more closely resemble :func:`boxplot`. Notably, thin boxes will remain visible when the edges are white. - |Enhancement| Allowed :func:`catplot` to use different values on the categorical axis of each facet when axis sharing is turned off (e.g. by specifying ``sharex=False``) (:pr:`2196`).
- |Enhancement| Improved the error messages produced when categorical plots process the orientation parameter. - |Enhancement| Added an explicit warning in :func:`swarmplot` when more than 5% of the points overlap in the "gutters" of the swarm (:pr:`2045`). Multi-plot grids ^^^^^^^^^^^^^^^^ - |Feature| |Enhancement| |Defaults| A few small changes to make life easier when using :class:`PairGrid` (:pr:`2234`): - Added public access to the legend object through the ``legend`` attribute (also affects :class:`FacetGrid`). - The ``color`` and ``label`` parameters are no longer passed to the plotting functions when ``hue`` is not used. - The data is no longer converted to a numpy object before plotting on the marginal axes. - It is possible to specify only one of ``x_vars`` or ``y_vars``, using all variables for the unspecified dimension. - The ``layout_pad`` parameter is stored and used every time you call the :meth:`PairGrid.tight_layout` method. - |Feature| Added a ``tight_layout`` method to :class:`FacetGrid` and :class:`PairGrid`, which runs the :func:`matplotlib.pyplot.tight_layout` algorithm without interference from the external legend (:pr:`2073`). - |Feature| Added the ``axes_dict`` attribute to :class:`FacetGrid` for named access to the component axes (:pr:`2046`). - |Enhancement| Made :meth:`FacetGrid.set_axis_labels` clear labels from "interior" axes (:pr:`2046`). - |Feature| Added the ``marginal_ticks`` parameter to :class:`JointGrid` which, if set to ``True``, will show ticks on the count/density axis of the marginal plots (:pr:`2210`). - |Enhancement| Improved :meth:`FacetGrid.set_titles` with ``margin_titles=True``, such that texts representing the original row titles are removed before adding new ones (:pr:`2083`). - |Defaults| Changed the default value for ``dropna`` to ``False`` in :class:`FacetGrid`, :class:`PairGrid`, :class:`JointGrid`, and corresponding functions. 
As all or nearly all seaborn and matplotlib plotting functions handle missing data well, this option is no longer useful, but it causes problems in some edge cases. It may be deprecated in the future. (:pr:`2204`). - |Fix| Fixed a bug in :class:`PairGrid` that appeared when setting ``corner=True`` and ``despine=False`` (:pr:`2203`). Color palettes ~~~~~~~~~~~~~~ - |Docs| Improved and modernized the :doc:`color palettes chapter </tutorial/color_palettes>` of the seaborn tutorial. - |Feature| Added two new perceptually-uniform colormaps: "flare" and "crest". The new colormaps are similar to "rocket" and "mako", but their luminance range is reduced. This makes them well suited to numeric mappings of line or scatter plots, which need contrast with the axes background at the extremes (:pr:`2237`). - |Enhancement| |Defaults| Enhanced numeric colormap functionality in several ways (:pr:`2237`): - Added string-based access within the :func:`color_palette` interface to :func:`dark_palette`, :func:`light_palette`, and :func:`blend_palette`. This means that anywhere you specify a palette in seaborn, a name like ``"dark:blue"`` will use :func:`dark_palette` with the input ``"blue"``. - Added the ``as_cmap`` parameter to :func:`color_palette` and changed internal code that uses a continuous colormap to take this route. - Tweaked the :func:`light_palette` and :func:`dark_palette` functions to use an endpoint that is a very desaturated version of the input color, rather than a pure gray. This produces smoother ramps. To exactly reproduce previous plots, use :func:`blend_palette` with ``".13"`` for dark or ``".95"`` for light. - Changed :func:`diverging_palette` to have a default value of ``sep=1``, which gives better results. - |Enhancement| Added a rich HTML representation to the object returned by :func:`color_palette` (:pr:`2225`). - |Fix| Fixed the ``"{palette}_d"`` logic to modify reversed colormaps and to use the correct direction of the luminance ramp in both cases. Deprecations and removals ^^^^^^^^^^^^^^^^^^^^^^^^^ - |Enhancement| Removed an optional (and undocumented) dependency on BeautifulSoup (:pr:`2190`) in :func:`get_dataset_names`. - |API| Deprecated the ``axlabel`` function; use ``ax.set(xlabel=, ylabel=)`` instead. - |API| Deprecated the ``iqr`` function; use :func:`scipy.stats.iqr` instead.
- |API| Final removal of the previously-deprecated ``annotate`` method on :class:`JointGrid`, along with related parameters. - |API| Final removal of the ``lvplot`` function (the previously-deprecated name for :func:`boxenplot`). ================================================ FILE: doc/whatsnew/v0.11.1.rst ================================================ v0.11.1 (December 2020) ----------------------- This is a bug fix release and is a recommended upgrade for all users on v0.11.0. - |Enhancement| Reduced the use of matplotlib global state in the :ref:`multi-grid classes <grid_api>` (:pr:`2388`). - |Fix| Restored support for using tuples or numeric keys to reference fields in a long-form `data` object (:pr:`2386`). - |Fix| Fixed a bug in :func:`lineplot` where NAs were propagating into the confidence interval, sometimes erasing it from the plot (:pr:`2273`). - |Fix| Fixed a bug in :class:`PairGrid`/:func:`pairplot` where diagonal axes would be empty when the grid was not square and the diagonal axes did not contain the marginal plots (:pr:`2270`). - |Fix| Fixed a bug in :class:`PairGrid`/:func:`pairplot` where off-diagonal plots would not appear when column names in `data` had non-string type (:pr:`2368`). - |Fix| Fixed a bug where categorical dtype information was ignored when data consisted of boolean or boolean-like values (:pr:`2379`). - |Fix| Fixed a bug in :class:`FacetGrid` where interior tick labels would be hidden when only the orthogonal axis was shared (:pr:`2347`). - |Fix| Fixed a bug in :class:`FacetGrid` that caused an error when `legend_out=False` was set (:pr:`2304`). - |Fix| Fixed a bug in :func:`kdeplot` where ``common_norm=True`` was ignored if ``hue`` was not assigned (:pr:`2378`). - |Fix| Fixed a bug in :func:`displot` where the ``row_order`` and ``col_order`` parameters were not used (:pr:`2262`). - |Fix| Fixed a bug in :class:`PairGrid`/:func:`pairplot` that caused an exception when using `corner=True` and `diag_kind=None` (:pr:`2382`).
- |Fix| Fixed a bug in :func:`clustermap` where `annot=False` was ignored (:pr:`2323`). - |Fix| Fixed a bug in :func:`clustermap` where row/col color annotations could not have a categorical dtype (:pr:`2389`). - |Fix| Fixed a bug in :func:`boxenplot` where the `linewidth` parameter was ignored (:pr:`2287`). - |Fix| Raise a more informative error in :class:`PairGrid`/:func:`pairplot` when no variables can be found to define the rows/columns of the grid (:pr:`2382`). - |Fix| Raise a more informative error from :func:`clustermap` if row/col color objects have semantic index but data object does not (:pr:`2313`). ================================================ FILE: doc/whatsnew/v0.11.2.rst ================================================ v0.11.2 (August 2021) --------------------- This is a minor release that addresses issues in the v0.11 series and adds a small number of targeted enhancements. It is a recommended upgrade for all users. - |API| |Enhancement| In :func:`lmplot`, added a new `facet_kws` parameter and deprecated the `sharex`, `sharey`, and `legend_out` parameters from the function signature; pass them in a `facet_kws` dictionary instead (:pr:`2576`). - |Feature| Added a :func:`move_legend` convenience function for repositioning the legend on an existing axes or figure, along with updating its properties. This function should be preferred over calling `ax.legend` with no legend data, which does not reliably work across seaborn plot types (:pr:`2643`). - |Feature| In :func:`histplot`, added `stat="percent"` as an option for normalization such that bar heights sum to 100 and `stat="proportion"` as an alias for the existing `stat="probability"` (:pr:`2461`, :pr:`2634`). - |Feature| Added :meth:`FacetGrid.refline` and :meth:`JointGrid.refline` methods for plotting horizontal and/or vertical reference lines on every subplot in one step (:pr:`2620`). 
- |Feature| In :func:`kdeplot`, added a `warn_singular` parameter to silence the warning about data with zero variance (:pr:`2566`). - |Enhancement| In :func:`histplot`, improved performance with large datasets and many groupings/facets (:pr:`2559`, :pr:`2570`). - |Enhancement| The :class:`FacetGrid`, :class:`PairGrid`, and :class:`JointGrid` objects now reference the underlying matplotlib figure with a `.figure` attribute. The existing `.fig` attribute still exists but is discouraged and may eventually be deprecated. The effect is that you can now call `obj.figure` on the return value from any seaborn function to access the matplotlib object (:pr:`2639`). - |Enhancement| In :class:`FacetGrid` and functions that use it, visibility of the interior axis labels is now disabled, and exterior axis labels are no longer erased when adding additional layers. This produces the same results for plots made by seaborn functions, but it may produce different (better, in most cases) results for customized facet plots (:pr:`2583`). - |Enhancement| In :class:`FacetGrid`, :class:`PairGrid`, and functions that use them, the matplotlib `figure.autolayout` parameter is disabled to avoid having the legend overlap the plot (:pr:`2571`). - |Enhancement| The :func:`load_dataset` helper now produces a more informative error when fed a dataframe, easing a common beginner mistake (:pr:`2604`). - |Fix| |Enhancement| Improved robustness to missing data, including some additional support for the `pd.NA` type (:pr:`2417`, :pr:`2435`). - |Fix| In :func:`ecdfplot` and :func:`rugplot`, fixed a bug where results were incorrect if the data axis had a log scale before plotting (:pr:`2504`). - |Fix| In :func:`histplot`, fixed a bug where using `shrink` with non-discrete bins shifted bar positions inaccurately (:pr:`2477`). - |Fix| In :func:`displot`, fixed a bug where `common_norm=False` was ignored when faceting was used without assigning `hue` (:pr:`2468`). 
- |Fix| In :func:`histplot`, fixed two bugs where automatically computed edge widths were too thick for log-scaled histograms and for categorical histograms on the y axis (:pr:`2522`). - |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `alpha` parameter was ignored when `fill=False` (:pr:`2460`). - |Fix| In :func:`histplot` and :func:`kdeplot`, fixed a bug where the `multiple` parameter was ignored when `hue` was provided as a vector without a name (:pr:`2462`). - |Fix| In :func:`displot`, the default alpha value now adjusts to a provided `multiple` parameter even when `hue` is not assigned (:pr:`2462`). - |Fix| In :func:`displot`, fixed a bug that caused faceted 2D histograms to error out with `common_bins=False` (:pr:`2640`). - |Fix| In :func:`rugplot`, fixed a bug that prevented the use of datetime data (:pr:`2458`). - |Fix| In :func:`relplot` and :func:`displot`, fixed a bug where the dataframe attached to the returned `FacetGrid` object dropped columns that were not used in the plot (:pr:`2623`). - |Fix| In :func:`relplot`, fixed an error that would be raised when one of the column names in the dataframe shared a name with one of the plot variables (:pr:`2581`). - |Fix| In the relational plots, fixed a bug where legend entries for the `size` semantic were incorrect when `size_norm` extrapolated beyond the range of the data (:pr:`2580`). - |Fix| In :func:`lmplot` and :func:`regplot`, fixed a bug where the x axis was clamped to the data limits with `truncate=True` (:pr:`2576`). - |Fix| In :func:`lmplot`, fixed a bug where `sharey=False` did not always work as expected (:pr:`2576`). - |Fix| In :func:`heatmap`, fixed a bug where vertically-rotated y-axis tick labels would be misaligned with their rows (:pr:`2574`). - |Fix| Fixed an issue that prevented Python from running in `-OO` mode while using seaborn (:pr:`2473`). - |Docs| Improved the API documentation for theme-related functions (:pr:`2573`). 
- |Docs| Added docstring pages for all methods on documented classes (:pr:`2644`). ================================================ FILE: doc/whatsnew/v0.12.0.rst ================================================ v0.12.0 (September 2022) ------------------------ Introduction of the objects interface ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This release debuts the `seaborn.objects` interface, an entirely new approach to making plots with seaborn. It is the product of several years of design and 16 months of implementation work. The interface aims to provide a more declarative, composable, and extensible API for making statistical graphics. It is inspired by Wilkinson's grammar of graphics, offering a Pythonic API that is informed by the design of libraries such as `ggplot2` and `vega-lite` along with lessons from the past 10 years of seaborn's development. For more information and numerous examples, see the :doc:`tutorial chapter </tutorial/objects_interface>` and :ref:`API reference <objects_api>`. This initial release should be considered "experimental". While it is stable enough for serious use, there are definitely some rough edges, and some key features remain to be implemented. It is possible that breaking changes may occur over the next few minor releases. Please be patient with any limitations that you encounter and help the development by reporting issues when you find behavior surprising. Keyword-only arguments ~~~~~~~~~~~~~~~~~~~~~~ |API| Seaborn's plotting functions now require explicit keywords for most arguments, following the deprecation of positional arguments in v0.11.0. With this enforcement, most functions have also had their parameter lists rearranged so that `data` is the first and only positional argument. This adds consistency across the various functions in the library. It also means that calling `func(data)` will do something for nearly all functions (those that support wide-form data) and that :class:`pandas.DataFrame` can be piped directly into a plot.
It is possible that the signatures will be loosened a bit in future releases so that `x` and `y` can be positional, but minimal support for positional arguments after this change will reduce the chance of inadvertent mis-specification (:pr:`2804`). Modernization of categorical scatterplots ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ This release begins the process of modernizing the :ref:`categorical plots <categorical_api>`, beginning with :func:`stripplot` and :func:`swarmplot`. These functions are sporting some enhancements that alleviate a few long-running frustrations (:pr:`2413`, :pr:`2447`): - |Feature| The new `native_scale` parameter allows numeric or datetime categories to be plotted with their original scale rather than converted to strings and plotted at fixed intervals. - |Feature| The new `formatter` parameter allows more control over the string representation of values on the categorical axis. There should also be improved defaults for some types, such as dates. - |Enhancement| It is now possible to assign `hue` when using only one coordinate variable (i.e. only `x` or `y`). - |Enhancement| It is now possible to disable the legend. The updates also harmonize behavior with functions that have been more recently introduced. This should be relatively non-disruptive, although a few defaults will change: - |Defaults| The functions now hook into matplotlib's unit system for plotting categorical data. (Seaborn's categorical functions actually predate support for categorical data in matplotlib.) This should mostly be transparent to the user, but it may resolve a few edge cases. For example, matplotlib interactivity should work better (e.g., for showing the data value under the cursor). - |Defaults| A color palette is no longer applied to levels of the categorical variable by default. It is now necessary to explicitly assign `hue` to see multiple colors (i.e., assign the same variable to `x`/`y` and `hue`).
Passing `palette` without `hue` will continue to be honored for one release cycle. - |Defaults| Numeric `hue` variables now receive a continuous mapping by default, using the same rules as :func:`scatterplot`. Pass `palette="deep"` to reproduce previous defaults. - |Defaults| The plots now follow the default property cycle; i.e. calling an axes-level function multiple times with the same active axes will produce different-colored artists. - |API| Currently, assigning `hue` and then passing a `color` will produce a gradient palette. This is now deprecated, as it is easy to request a gradient with, e.g. `palette="light:blue"`. Similar enhancements / updates should be expected to roll out to other categorical plotting functions in future releases. There are also several function-specific enhancements: - |Enhancement| In :func:`stripplot`, a "strip" with a single observation will be plotted without jitter (:pr:`2413`) - |Enhancement| In :func:`swarmplot`, the points are now swarmed at draw time, meaning that the plot will adapt to further changes in axis scaling or tweaks to the plot layout (:pr:`2443`). - |Feature| In :func:`swarmplot`, the proportion of points that must overlap before issuing a warning can now be controlled with the `warn_thresh` parameter (:pr:`2447`). - |Fix| In :func:`swarmplot`, the order of the points in each swarm now matches the order in the original dataset; previously they were sorted. This affects only the underlying data stored in the matplotlib artist, not the visual representation (:pr:`2443`). More flexible errorbars ~~~~~~~~~~~~~~~~~~~~~~~ |API| |Feature| Increased the flexibility of what can be shown by the internally-calculated errorbars for :func:`lineplot`, :func:`barplot`, and :func:`pointplot`. With the new `errorbar` parameter, it is now possible to select bootstrap confidence intervals, percentile / predictive intervals, or intervals formed by scaled standard deviations or standard errors. 
The parameter also accepts an arbitrary function that maps from a vector to an interval. There is a new :doc:`user guide chapter </tutorial/error_bars>` demonstrating these options and explaining when you might want to use each one. As a consequence of this change, the `ci` parameter has been deprecated. Note that :func:`regplot` retains the previous API, but it will likely be updated in a future release (:pr:`2407`, :pr:`2866`). Other updates ~~~~~~~~~~~~~ - |Feature| It is now possible to aggregate / sort a :func:`lineplot` along the y axis using `orient="y"` (:pr:`2854`). - |Feature| Made it easier to customize :class:`FacetGrid` / :class:`PairGrid` / :class:`JointGrid` with a fluent (method-chained) style by adding `apply`/ `pipe` methods. Additionally, fixed the `tight_layout` and `refline` methods so that they return `self` (:pr:`2926`). - |Feature| Added :meth:`FacetGrid.tick_params` and :meth:`PairGrid.tick_params` to customize the appearance of the ticks, tick labels, and gridlines of all subplots at once (:pr:`2944`). - |Enhancement| Added a `width` parameter to :func:`barplot` (:pr:`2860`). - |Enhancement| It is now possible to specify `estimator` as a string in :func:`barplot` and :func:`pointplot`, in addition to a callable (:pr:`2866`). - |Enhancement| Error bars in :func:`regplot` now inherit the alpha value of the points they correspond to (:pr:`2540`). - |Enhancement| When using :func:`pairplot` with `corner=True` and `diag_kind=None`, the top left y axis label is no longer hidden (:pr:`2850`). - |Enhancement| It is now possible to plot a discrete :func:`histplot` as a step function or polygon (:pr:`2859`). - |Enhancement| It is now possible to customize the appearance of elements in a :func:`boxenplot` with `box_kws`/`line_kws`/`flier_kws` (:pr:`2909`). - |Fix| Improved integration with the matplotlib color cycle in most axes-level functions (:pr:`2449`).
- |Fix| Fixed a regression in 0.11.2 that caused some functions to stall indefinitely or raise when the input data had a duplicate index (:pr:`2776`). - |Fix| Fixed a bug in :func:`histplot` and :func:`kdeplot` where weights were not factored into the normalization (:pr:`2812`). - |Fix| Fixed two edge cases in :func:`histplot` when only `binwidth` was provided (:pr:`2813`). - |Fix| Fixed a bug in :func:`violinplot` where inner boxes/points could be missing with unpaired split violins (:pr:`2814`). - |Fix| Fixed a bug in :class:`PairGrid` where an error would be raised when defining `hue` only in the mapping methods (:pr:`2847`). - |Fix| Fixed a bug in :func:`scatterplot` where an error would be raised when `hue_order` was a subset of the hue levels (:pr:`2848`). - |Fix| Fixed a bug in :func:`histplot` where dodged bars would have different widths on a log scale (:pr:`2849`). - |Fix| In :func:`lineplot`, allowed the `dashes` keyword to set the style of a line without mapping a `style` variable (:pr:`2449`). - |Fix| Improved support in :func:`relplot` for "wide" data and for faceting variables passed as non-pandas objects (:pr:`2846`). - |Fix| Subplot titles will no longer be reset when calling :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe` (:pr:`2705`). - |Fix| Added a workaround for a matplotlib issue that caused figure-level functions to freeze when `plt.show` was called (:pr:`2925`). - |Fix| Improved robustness to numerical errors in :func:`kdeplot` (:pr:`2862`). - |Fix| Fixed a bug where :func:`rugplot` was ignoring `expand_margins=False` (:pr:`2953`). - |Defaults| The `patch.facecolor` rc param is no longer set by :func:`set_palette` (or :func:`set_theme`). This should have no general effect, because the matplotlib default is now `"C0"` (:pr:`2906`). - |Build| Made `scipy` an optional dependency and added `pip install seaborn[stats]` as a method for ensuring the availability of compatible `scipy` and `statsmodels` libraries at install time.
This has a few minor implications for existing code, which are explained in the Github pull request (:pr:`2398`). - |Build| Example datasets are now stored in an OS-specific cache location (as determined by `appdirs`) rather than in the user's home directory. Users should feel free to remove `~/seaborn-data` if desired (:pr:`2773`). - |Build| The unit test suite is no longer part of the source or wheel distribution. Seaborn has never had a runtime API for exercising the tests, so this should not have workflow implications (:pr:`2833`). - |Build| Following `NEP29 `_, dropped support for Python 3.6 and bumped the minimally-supported versions of the library dependencies. - |API| Removed the previously-deprecated `factorplot` along with several previously-deprecated utility functions (`iqr`, `percentiles`, `pmf_hist`, and `sort_df`). - |API| Removed the (previously-unused) option to pass additional keyword arguments to :func:`pointplot`. ================================================ FILE: doc/whatsnew/v0.12.1.rst ================================================ v0.12.1 (October 2022) ---------------------- This is an incremental release that is a recommended upgrade for all users. It addresses a handful of bugs / regressions in v0.12.0 and adds several features and enhancements to the new :doc:`objects interface `. - |Feature| Added the :class:`objects.Text` mark (:pr:`3051`). - |Feature| Added the :class:`objects.Dash` mark (:pr:`3074`). - |Feature| Added the :class:`objects.Perc` stat (:pr:`3063`). - |Feature| Added the :class:`objects.Count` stat (:pr:`3086`). - |Feature| The :class:`objects.Band` and :class:`objects.Range` marks will now cover the full extent of the data if `min` / `max` variables are not explicitly assigned or added in a transform (:pr:`3056`). - |Enhancement| |Defaults| The :class:`objects.Jitter` move now applies a small amount of jitter by default (:pr:`3066`). 
- |Enhancement| |Defaults| Axes with a :class:`objects.Nominal` scale now appear like categorical axes in classic seaborn, with fixed margins, no grid, and an inverted y axis (:pr:`3069`). - |Enhancement| |API| The :meth:`objects.Continuous.label` method now accepts `base=None` to override the default formatter with a log transform (:pr:`3087`). - |Enhancement| |Fix| Marks that sort along the orient axis (e.g. :class:`objects.Line`) now use a stable algorithm (:pr:`3064`). - |Enhancement| |Fix| Added a `label` parameter to :func:`pointplot`, which addresses a regression in 0.12.0 when :func:`pointplot` is passed to :class:`FacetGrid` (:pr:`3016`). - |Fix| Fixed a bug that caused an exception when more than two layers with the same mappings were added to :class:`objects.Plot` (:pr:`3055`). - |Fix| Made :class:`objects.PolyFit` robust to missing data (:pr:`3010`). - |Fix| Fixed a bug in :class:`objects.Plot` that occurred when data assigned to the orient coordinate had zero variance (:pr:`3084`). - |Fix| Fixed a regression in :func:`kdeplot` where passing `cmap` for an unfilled bivariate plot would raise an exception (:pr:`3065`). - |Fix| Addressed a performance regression in :func:`lineplot` with a large number of unique x values (:pr:`3081`). - |Build| Seaborn no longer contains doctest-style examples, simplifying the testing infrastructure (:pr:`3034`). ================================================ FILE: doc/whatsnew/v0.12.2.rst ================================================ v0.12.2 (December 2022) ----------------------- This is an incremental release that is a recommended upgrade for all users. It is very likely the final release of the 0.12 series and the last version to support Python 3.7. - |Feature| Added the :class:`objects.KDE` stat (:pr:`3111`). - |Feature| Added the :class:`objects.Boolean` scale (:pr:`3205`). 
- |Enhancement| Improved user feedback for failures during plot compilation by catching exceptions and re-raising with a `PlotSpecError` that provides additional context (:pr:`3203`). - |Fix| Improved calculation of automatic mark widths with unshared facet axes (:pr:`3119`). - |Fix| Improved robustness to empty data in several components of the objects interface (:pr:`3202`). - |Fix| Fixed a bug where legends for numeric variables with large values would be incorrectly shown (i.e. with a missing offset or exponent; :pr:`3187`). - |Fix| Fixed a regression in v0.12.0 where manually-added labels could have duplicate legend entries (:pr:`3116`). - |Fix| Fixed a bug in :func:`histplot` with `kde=True` and `log_scale=True` where the curve was not scaled properly (:pr:`3173`). - |Fix| Fixed a bug in :func:`relplot` where inner axis labels would be shown when axis sharing was disabled (:pr:`3180`). - |Fix| Fixed a bug in :class:`objects.Continuous` to avoid an exception with boolean data (:pr:`3189`). ================================================ FILE: doc/whatsnew/v0.13.0.rst ================================================ v0.13.0 (September 2023) ------------------------ This is a major release with a number of important new features and changes. The highlight is a major overhaul to seaborn's categorical plotting functions, providing them with many new capabilities and better aligning their API with the rest of the library. There is also provisional support for alternate dataframe libraries like `polars `_, a new theme and display configuration system for :class:`objects.Plot`, and many smaller bugfixes and enhancements. Updating is recommended, but users are encouraged to carefully check the outputs of existing code that uses the categorical functions, and they should be aware of some deprecations and intentional changes to the default appearance of the resulting plots (see notes below with |API| and |Defaults| tags). 
Major enhancements to categorical plots ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ Seaborn's :ref:`categorical functions ` have been completely rewritten for this release. This provided the opportunity to address some longstanding quirks as well as to add a number of smaller but much-desired features and enhancements. Support for numeric and datetime data ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |Feature| The categorical functions have historically treated *all* data as categorical, even when it has a numeric or datetime type. This can now be controlled with the new `native_scale` parameter. The default remains `False` to preserve existing behavior. But with `native_scale=True`, values will be treated as they would by other seaborn or matplotlib functions. Element widths will be derived from the minimum distance between two unique values on the categorical axis. Additionally, while seaborn previously determined the mapping from categorical values to ordinal positions internally, this is now delegated to matplotlib. The change should mostly be transparent to the user, but categorical plots (even with `native_scale=False`) will better align with artists added by other seaborn or matplotlib functions in most cases, and matplotlib's interactive machinery will work better. Changes to color defaults and specification ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ |API| |Defaults| The categorical functions now act more like the rest of seaborn in that they will produce a plot with a single main color unless the `hue` variable is assigned. Previously, there would be an implicit redundant color mapping (e.g., each box in a boxplot would get a separate color from the default palette). To retain the previous behavior, explicitly assign a redundant `hue` variable (e.g., `boxplot(data, x="x", y="y", hue="x")`). 
Two related idiosyncratic color specifications are deprecated, but they will continue to work (with a warning) for one release cycle: - Passing a `palette` without explicitly assigning `hue` is no longer supported (add an explicitly redundant `hue` assignment instead). - Passing a `color` while assigning `hue` to produce a gradient is no longer supported (use `palette="dark:{color}"` or `palette="light:{color}"` instead). Finally, like other seaborn functions, the default palette now depends on the variable type, and a sequential palette will be used with numeric data. To retain the previous behavior, pass the name of a qualitative palette (e.g., `palette="deep"` for seaborn's default). Accordingly, the functions have gained a parameter to control numeric color mappings (`hue_norm`). Other features, enhancements, and changes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The following updates apply to multiple categorical functions. - |Feature| All functions now accept a `legend` parameter, which can be a boolean (to suppress the legend) or one of `{"auto", "brief", "full"}` to control the amount of information shown in the legend for a numerical color mapping. - |Feature| All functions now accept a callable `formatter` parameter to control the string representation of the data. - |Feature| All functions that draw a solid patch now accept a boolean `fill` parameter, which when set to `False` will draw line-art elements. - |Feature| All functions that support dodging now have an additional `gap` parameter that can be set to a non-zero value to leave space between dodged elements. - |Feature| The :func:`boxplot`, :func:`boxenplot`, and :func:`violinplot` functions now support a single `linecolor` parameter. - |Enhancement| The default value for `dodge` has changed from `True` to `"auto"`. With `"auto"`, elements will dodge only when at least one set of elements would otherwise overlap. 
- |Enhancement| When the value axis of the plot has a non-linear scale, the statistical operations (e.g. an aggregation in :func:`pointplot` or the kernel density fit in :func:`violinplot`) are now applied in that scale space. - |Enhancement| All functions now accept a `log_scale` parameter. With a single argument, this will set the scale on the "value" axis (*opposite* the categorical axis). A tuple will set each axis directly (although setting a log scale categorical axis also requires `native_scale=True`). - |Enhancement| The `orient` parameter now accepts `"x"/"y"` to specify the categorical axis, matching the objects interface. - |Enhancement| The categorical functions are generally more deferential to the user's additional matplotlib keyword arguments. - |API| Using `"gray"` to select an automatic gray value that complements the main palette is now deprecated in favor of `"auto"`. The following updates are function-specific. - |API| |Feature| In :func:`pointplot`, a single :class:`matplotlib.lines.Line2D` artist is now used rather than adding a separate :class:`matplotlib.collections.PathCollection` artist for the points. As a result, it is now possible to pass additional keyword arguments for complete customization of the appearance of both the lines and markers; additionally, the legend representation is improved. Accordingly, parameters that previously allowed only partial customization (`scale`, `join`, and `errwidth`) are now deprecated. The old parameters will now trigger detailed warning messages with instructions for adapting existing code. - |API| |Feature| The bandwidth specification in :func:`violinplot` better aligns with :func:`kdeplot`, as the `bw` parameter is now deprecated in favor of `bw_method` and `bw_adjust`. - |API| |Enhancement| In :func:`boxenplot`, the boxen are now drawn with separate patch artists in each tail. 
This may have consequences for code that works with the underlying artists, but it produces a better result for low-alpha / unfilled plots and enables proper area/density scaling. - |API| |Enhancement| In :func:`barplot`, the `errcolor` and `errwidth` parameters are now deprecated in favor of a more general `err_kws` dictionary. The existing parameters will continue to work for two releases. - |API| In :func:`violinplot`, the `scale` and `scale_hue` parameters have been renamed to `density_norm` and `common_norm` for clarity and to reflect the fact that common normalization is now applied over both hue and faceting variables in :func:`catplot`. - |API| In :func:`boxenplot`, the `scale` parameter has been renamed to `width_method` as part of a broader effort to de-confound the meaning of "scale" in seaborn parameters. - |Defaults| |Enhancement| When passing a vector to the `data` parameter of :func:`barplot` or :func:`pointplot`, a bar or point will be drawn for each entry in the vector rather than plotting a single aggregated value. To retain the previous behavior, assign the vector to the `y` variable. - |Defaults| |Enhancement| In :func:`boxplot`, the default flier marker now follows the matplotlib rcparams so that it can be globally customized. - |Defaults| |Enhancement| When using `split=True` and `inner="box"` in :func:`violinplot`, a separate mini-box is now drawn for each split violin. - |Defaults| |Enhancement| In :func:`boxenplot`, all plots now use a consistent luminance ramp for the different box levels. This leads to a change in the appearance of existing plots, but reduces the chances of a misleading result. - |Defaults| |Enhancement| The `"area"` scaling in :func:`boxenplot` now approximates the density of the underlying observations, including for asymmetric distributions. This produces a substantial change in the appearance of plots with `width_method="area"`, although the existing behavior was poorly defined. 
- |Feature| In :func:`countplot`, the new `stat` parameter can be used to apply a normalization (e.g. to show a `"percent"` or `"proportion"`). - |Feature| The `split` parameter in :func:`violinplot` is now more general and can be set to `True` regardless of the number of `hue` variable levels (or even without `hue`). This is probably most useful for showing half violins. - |Feature| In :func:`violinplot`, the new `inner_kws` parameter allows additional control over the interior artists. - |Enhancement| It is no longer required to use a `DataFrame` in :func:`catplot`, as data vectors can now be passed directly. - |Enhancement| In :func:`boxplot`, the artists that comprise each box plot are now packaged in a `BoxPlotContainer` for easier post-plotting access. Support for alternate dataframe libraries ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - |Feature| Nearly all functions / objects now use the `dataframe exchange protocol `_ to accept `DataFrame` objects from libraries other than `pandas` (e.g. `polars`). Note that seaborn will still convert the data object to pandas internally, but this feature will simplify code for users of other dataframe libraries (:pr:`3369`). Improved configuration for the objects interface ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ - |Feature| Added control over the default theme to :class:`objects.Plot` (:pr:`3223`). - |Feature| Added control over the default notebook display to :class:`objects.Plot` (:pr:`3225`). - |Feature| Added the concept of a "layer legend" in :class:`objects.Plot` via the new `label` parameter in :meth:`objects.Plot.add` (:pr:`3456`). - |Enhancement| In :meth:`objects.Plot.scale`, :meth:`objects.Plot.limit`, and :meth:`objects.Plot.label` the `x` / `y` parameters can be used to set a common scale / limit / label for paired subplots (:pr:`3458`). 
Other updates ^^^^^^^^^^^^^ - |Enhancement| Improved the legend display for relational and categorical functions to better represent the user's additional keyword arguments (:pr:`3467`). - |Enhancement| In :func:`ecdfplot`, `stat="percent"` is now a valid option (:pr:`3336`). - |Enhancement| Data values outside the scale transform domain (e.g. non-positive values with a log scale) are now dropped prior to any statistical operations (:pr:`3488`). - |Enhancement| In :func:`histplot`, infinite values are now ignored when choosing the default bin range (:pr:`3488`). - |Enhancement| There is now generalized support for performing statistics in the appropriate space based on axes scales; previously support for this was spotty and at best worked only for log scales (:pr:`3440`). - |Enhancement| Updated :func:`load_dataset` to use an approach more compatible with `pyodide` (:pr:`3234`). - |API| Support for array-typed palettes is now deprecated. This was not previously documented as supported, but it worked by accident in a few places (:pr:`3452`). - |API| |Fix| In :func:`histplot`, treatment of the `binwidth` parameter has changed such that the actual bin width will be only approximately equal to the requested width when that value does not evenly divide the bin range. This fixes an issue where the largest data value was sometimes dropped due to floating point error (:pr:`3489`). - |Fix| Fixed :class:`objects.Bar` and :class:`objects.Bars` widths when using a nonlinear scale (:pr:`3217`). - |Fix| Worked around an issue in matplotlib that caused incorrect results in :func:`move_legend` when `labels` were provided (:pr:`3454`). - |Fix| Fixed a bug introduced in v0.12.0 where :func:`histplot` added a stray empty `BarContainer` (:pr:`3246`). - |Fix| Fixed a bug where :meth:`objects.Plot.on` would override a figure's layout engine (:pr:`3216`). 
- |Fix| Fixed a bug introduced in v0.12.0 where :func:`lineplot` with a list of tuples for the keyword argument dashes caused a TypeError (:pr:`3316`). - |Fix| Fixed a bug in :class:`PairGrid` that caused an exception when the input dataframe had a column multiindex (:pr:`3407`). - |Fix| Improved a few edge cases when using pandas nullable dtypes (:pr:`3394`). ================================================ FILE: doc/whatsnew/v0.13.1.rst ================================================ v0.13.1 (December 2023) ----------------------- This is a minor release with some bug fixes and a couple new features. All users are encouraged to update. - |Feature| Added support for weighted mean estimation (with bootstrap CIs) in :func:`lineplot`, :func:`barplot`, :func:`pointplot`, and :class:`objects.Est` (:pr:`3580`, :pr:`3586`). - |Feature| Added the `extent` option to :meth:`objects.Plot.layout` (:pr:`3552`). - |Fix| Fixed a regression in v0.13.0 that triggered an exception when working with non-numpy data types (:pr:`3516`). - |Fix| Fixed a bug in :class:`objects.Plot` so that tick labels are shown for wrapped axes that aren't in the bottom-most row (:pr:`3600`). - |Fix| Fixed a bug in :func:`catplot` where a blank legend would be added when `hue` was redundantly assigned (:pr:`3540`). - |Fix| Fixed a bug in :func:`catplot` where the `edgecolor` parameter was ignored with `kind="bar"` (:pr:`3547`). - |Fix| Fixed a bug in :func:`boxplot` where an exception was raised when using the matplotlib `bootstrap` option (:pr:`3562`). - |Fix| Fixed a bug in :func:`lineplot` where an exception was raised when `hue` was assigned with an empty dataframe (:pr:`3569`). - |Fix| Fixed a bug in multiple categorical plots that raised with `hue=None` and `dodge=True`; this now has no effect (:pr:`3605`). 
================================================ FILE: doc/whatsnew/v0.13.2.rst ================================================ v0.13.2 (January 2024) ---------------------- This is a minor release containing internal changes that adapt to upcoming deprecations in pandas. All users are encouraged to update. ================================================ FILE: doc/whatsnew/v0.2.0.rst ================================================ v0.2.0 (December 2013) ---------------------- This is a major release from 0.1 with a number of API changes, enhancements, and bug fixes. Highlights include an overhaul of timeseries plotting to work intelligently with dataframes, the new function ``interactplot()`` for visualizing continuous interactions, bivariate kernel density estimates in ``kdeplot()``, and significant improvements to color palette handling. Version 0.2 also introduces experimental support for Python 3. In addition to the library enhancements, the documentation has been substantially rewritten to reflect the new features and improve the presentation of the ideas behind the package. API changes ~~~~~~~~~~~ - The ``tsplot()`` function was rewritten to accept data in a long-form ``DataFrame`` and to plot different traces by condition. This introduced a relatively minor but unavoidable API change, where instead of doing ``sns.tsplot(time, heights)``, you now must do ``sns.tsplot(heights, time=time)`` (the ``time`` parameter is now optional, for quicker specification of simple plots). Additionally, the ``"obs_traces"`` and ``"obs_points"`` error styles in ``tsplot()`` have been renamed to ``"unit_traces"`` and ``"unit_points"``, respectively. - Functions that fit kernel density estimates (``kdeplot()`` and ``violinplot()``) now use ``statsmodels`` instead of ``scipy``, and the parameters that influence the density estimate have changed accordingly. 
This allows for increased flexibility in specifying the bandwidth and kernel, and smarter choices for defining the range of the support. Default options should produce plots that are very close to the old defaults. - The ``kdeplot()`` function now takes a second positional argument of data for drawing bivariate densities. - The ``violin()`` function has been changed to ``violinplot()``, for consistency. In 0.2, ``violin`` will still work, but it will fire a ``UserWarning``. New plotting functions ~~~~~~~~~~~~~~~~~~~~~~ - The ``interactplot()`` function draws a contour plot for an interactive linear model (i.e., the contour shows ``y-hat`` from the model ``y ~ x1 * x2``) over a scatterplot between the two predictor variables. This plot should aid the understanding of an interaction between two continuous variables. - The ``kdeplot()`` function can now draw a bivariate density estimate as a contour plot if provided with two-dimensional input data. - The ``palplot()`` function provides a simple grid-based visualization of a color palette. Other changes ~~~~~~~~~~~~~ Plotting functions ^^^^^^^^^^^^^^^^^^ - The ``corrplot()`` function can be drawn without the correlation coefficient annotation and with variable names on the side of the plot to work with large datasets. - Additionally, ``corrplot()`` sets the color palette intelligently based on the direction of the specified test. - The ``distplot()`` histogram uses a reference rule to choose the bin size if it is not provided. - Added the ``x_bins`` option in ``lmplot()`` for binning a continuous predictor variable, allowing for clearer trends with many datapoints. - Enhanced support for labeling plot elements and axes based on ``name`` attributes in several distribution plot functions and ``tsplot()`` for smarter Pandas integration. - Scatter points in ``lmplot()`` are slightly transparent so it is easy to see where observations overlap. 
- Added the ``order`` parameter to ``boxplot()`` and ``violinplot()`` to control the order of the bins when using a Pandas object. - When an ``ax`` argument is not provided to a plotting function, it grabs the currently active axis instead of drawing a new one. Color palettes ^^^^^^^^^^^^^^ - Added the ``dark_palette()`` and ``blend_palette()`` for on-the-fly creation of blended color palettes. - The color palette machinery is now intelligent about qualitative ColorBrewer palettes (``Set1``, ``Paired``, etc.), which are properly treated as discrete. - Seaborn color palettes (``deep``, ``muted``, etc.) have been standardized in terms of basic hue sequence, and all palettes now have 6 colors. - Introduced ``{mpl_palette}_d`` palettes, which make a palette with the basic color scheme of the source palette, but with a sequential blend from dark instead of light colors for use with line/scatter/contour plots. - Added the ``palette_context()`` function for blockwise color palettes controlled by a ``with`` statement. Plot styling ^^^^^^^^^^^^ - Added the ``despine()`` function for easily removing plot spines. - A new plot style, ``"ticks"`` has been added. - Tick labels are padded a bit farther from the axis in all styles, avoiding collisions at (0, 0). General package issues ^^^^^^^^^^^^^^^^^^^^^^ - Reorganized the package by breaking up the monolithic ``plotobjs`` module into smaller modules grouped by general objective of the constituent plots. - Removed the ``scikits-learn`` dependency in ``moss``. - Installing with ``pip`` should automatically install most missing dependencies. - The example notebooks are now used as an automated test suite. Bug fixes ~~~~~~~~~ - Fixed a bug where labels did not match data for ``boxplot()`` and ``violinplot()`` when using a groupby. - Fixed a bug in the ``desaturate()`` function. - Fixed a bug in the ``coefplot()`` figure size calculation. - Fixed a bug where ``regplot()`` choked on list input. 
- Fixed buggy behavior when drawing horizontal boxplots. - Specifying bins for the ``distplot()`` histogram now works. - Fixed a bug where ``kdeplot()`` would reset the axis height and cut off existing data. - All axis styling has been moved out of the top-level ``seaborn.set()`` function, so context or color palette can be cleanly changed. ================================================ FILE: doc/whatsnew/v0.2.1.rst ================================================ v0.2.1 (December 2013) ---------------------- This is a bugfix release, with no new features. Bug fixes ~~~~~~~~~ - Changed the mechanics of ``violinplot()`` and ``boxplot()`` when using a ``Series`` object as data and performing a ``groupby`` to assign data to bins to address a problem that arises in Pandas 0.13. - Additionally fixed the ``groupby`` code to work with all styles of group specification (specifically, using a dictionary or a function now works). - Fixed a bug where artifacts from the kde fitting could undershoot and create a plot where the density axis starts below 0. - Ensured that data used for kde fitting is double-typed to avoid a low-level statsmodels error. - Changed the implementation of the histogram bin-width reference rule to take a ceiling of the estimated number of bins. ================================================ FILE: doc/whatsnew/v0.3.0.rst ================================================ v0.3.0 (March 2014) ------------------- This is a major release from 0.2 with a number of enhancements to the plotting capabilities and styles. Highlights include :class:`FacetGrid`, ``factorplot``, :func:`jointplot`, and an overhaul to :ref:`style management `. There is also lots of new documentation, including an :ref:`example gallery ` and reorganized :ref:`tutorial `. New plotting functions ~~~~~~~~~~~~~~~~~~~~~~ - The :class:`FacetGrid` class adds a new form of functionality to seaborn, providing a way to abstractly structure a grid of plots corresponding to subsets of a dataset. 
It can be used with a wide variety of plotting functions (including most of the matplotlib and seaborn APIs). See the :ref:`tutorial ` for more information. - Version 0.3 introduces the ``factorplot`` function, which is similar in spirit to :func:`lmplot` but intended for use when the main independent variable is categorical instead of quantitative. ``factorplot`` can draw a plot in either a point or bar representation using the corresponding Axes-level functions :func:`pointplot` and :func:`barplot` (which are also new). Additionally, the ``factorplot`` function can be used to draw box plots on a faceted grid. For examples of how to use these functions, you can refer to the tutorial. - Another new function is :func:`jointplot`, which is built using the new :class:`JointGrid` object. :func:`jointplot` generalizes the behavior of :func:`regplot` in previous versions of seaborn (:func:`regplot` has changed somewhat in 0.3; see below for details) by drawing a bivariate plot of the relationship between two variables with their marginal distributions drawn on the side of the plot. With :func:`jointplot`, you can draw a scatterplot or regression plot as before, but you can now also draw bivariate kernel densities or hexbin plots with appropriate univariate graphs for the marginal distributions. Additionally, it's easy to use :class:`JointGrid` directly to build up more complex plots when the default methods offered by :func:`jointplot` are not suitable for your visualization problem. The tutorial for :class:`JointGrid` has more examples of how this object can be useful. - The :func:`residplot` function complements :func:`regplot` and can be quickly used to diagnose problems with a linear model by calculating and plotting the residuals of a simple regression. There is also a ``"resid"`` kind for :func:`jointplot`. API changes ~~~~~~~~~~~ - The most noticeable change will be that :func:`regplot` no longer produces a multi-component plot with distributions in marginal axes. 
Instead, :func:`regplot` is now an "Axes-level" function that can be plotted into any existing figure on a specific set of axes. :func:`regplot` and :func:`lmplot` have also been unified (the latter uses the former behind the scenes), so all options for how to fit and represent the regression model can be used for both functions. To get the old behavior of :func:`regplot`, use :func:`jointplot` with ``kind="reg"``. - As noted above, :func:`lmplot` has been rewritten to exploit the :class:`FacetGrid` machinery. This involves a few changes. The ``color`` keyword argument has been replaced with ``hue``, for better consistency across the package. The ``hue`` parameter will always take a variable *name*, while ``color`` will take a color name or (in some cases) a palette. The :func:`lmplot` function now returns the :class:`FacetGrid` used to draw the plot instance. - The functions that interact with matplotlib rc parameters have been updated and standardized. There are now three pairs of functions, :func:`axes_style` and :func:`set_style`, :func:`plotting_context` and :func:`set_context`, and :func:`color_palette` and :func:`set_palette`. In each case, the pairs take the exact same arguments. The first function defines and returns the parameters, and the second sets the matplotlib defaults. Additionally, the first function in each pair can be used in a ``with`` statement to temporarily change the defaults. Both the style and context functions also now accept a dictionary of matplotlib rc parameters to override the seaborn defaults, and :func:`set` now also takes a dictionary to update any of the matplotlib defaults. See the :ref:`tutorial ` for more information. - The ``nogrid`` style has been deprecated and changed to ``white`` for more uniformity (i.e. there are now ``darkgrid``, ``dark``, ``whitegrid``, and ``white`` styles). 
Other changes ~~~~~~~~~~~~~ Using the package ^^^^^^^^^^^^^^^^^ - If you want to use plotting functions provided by the package without setting the matplotlib style to a seaborn theme, you can now do ``import seaborn.apionly as sns`` or ``from seaborn.apionly import lmplot``, etc. This is using the (also new) :func:`reset_orig` function, which returns the rc parameters to what they are at matplotlib import time — i.e. they will respect any custom `matplotlibrc` settings on top of the matplotlib defaults. - The dependency load of the package has been reduced. It can now be installed and used with only ``numpy``, ``scipy``, ``matplotlib``, and ``pandas``. Although ``statsmodels`` is still recommended for full functionality, it is not required. Plotting functions ^^^^^^^^^^^^^^^^^^ - :func:`lmplot` (and :func:`regplot`) have two new options for fitting regression models: ``lowess`` and ``robust``. The former fits a nonparametric smoother, while the latter fits a regression using methods that are less sensitive to outliers. - The regression uncertainty in :func:`lmplot` and :func:`regplot` is now estimated with fewer bootstrap iterations, so plotting should be faster. - The univariate :func:`kdeplot` can now be drawn as a *cumulative* density plot. - Changed :func:`interactplot` to use a robust calculation of the data range when finding default limits for the contour colormap to work better when there are outliers in the data. Style ^^^^^ - There is a new style, ``dark``, which shares most features with ``darkgrid`` but does not draw a grid by default. - There is a new function, :func:`offset_spines`, and a corresponding option in :func:`despine` called ``trim``. Together, these can be used to make plots where the axis spines are offset from the main part of the figure and limited within the range of the ticks. This is recommended for use with the ``ticks`` style. - Other aspects of the seaborn styles have been tweaked for more attractive plots. 
================================================ FILE: doc/whatsnew/v0.3.1.rst ================================================ v0.3.1 (April 2014) ------------------- This is a minor release from 0.3 with fixes for several bugs. Plotting functions ~~~~~~~~~~~~~~~~~~ - The size of the points in :func:`pointplot` and ``factorplot`` are now scaled with the linewidth for better aesthetics across different plotting contexts. - The :func:`pointplot` glyphs for different levels of the hue variable are drawn at different z-orders so that they appear uniform. Bug Fixes ~~~~~~~~~ - Fixed a bug in :class:`FacetGrid` (and thus affecting lmplot and factorplot) that appeared when ``col_wrap`` was used with a number of facets that did not evenly divide into the column width. - Fixed an issue where the support for kernel density estimates was sometimes computed incorrectly. - Fixed a problem where ``hue`` variable levels that were not strings were missing in :class:`FacetGrid` legends. - When passing a color palette list in a ``with`` statement, the entire palette is now used instead of the first six colors. ================================================ FILE: doc/whatsnew/v0.4.0.rst ================================================ v0.4.0 (September 2014) ----------------------- This is a major release from 0.3. Highlights include new approaches for :ref:`quick, high-level dataset exploration ` (along with a more :ref:`flexible interface `) and easy creation of :ref:`perceptually-appropriate color palettes ` using the cubehelix system. Along with these additions, there are a number of smaller changes that make visualizing data with seaborn easier and more powerful. Plotting functions ~~~~~~~~~~~~~~~~~~ - A new object, :class:`PairGrid`, and a corresponding function :func:`pairplot`, for drawing grids of pairwise relationships in a dataset. 
This style of plot is sometimes called a "scatterplot matrix", but the representation of the data in :class:`PairGrid` is flexible and many styles other than scatterplots can be used. See the :ref:`docs ` for more information. **Note:** due to a bug in older versions of matplotlib, you will have best results if you use these functions with matplotlib 1.4 or later. - The rules for choosing default color palettes when variables are mapped to different colors have been unified (and thus changed in some cases). Now when no specific palette is requested, the current global color palette will be used, unless the number of variables to be mapped exceeds the number of unique colors in the palette, in which case the ``"husl"`` palette will be used to avoid cycling. - Added a keyword argument ``hist_norm`` to :func:`distplot`. When a :func:`distplot` is now drawn without a KDE or parametric density, the histogram is drawn as counts instead of a density. This can be overridden by setting ``hist_norm`` to ``True``. - When using :class:`FacetGrid` with a ``hue`` variable, the legend is no longer drawn by default when you call :meth:`FacetGrid.map`. Instead, you have to call :meth:`FacetGrid.add_legend` manually. This should make it easier to layer multiple plots onto the grid without having duplicated legends. - Made some changes to ``factorplot`` so that it behaves better when not all levels of the ``x`` variable are represented in each facet. - Added the ``logx`` option to :func:`regplot` for fitting the regression in log space. - When :func:`violinplot` encounters a bin with only a single observation, it will now plot a horizontal line at that value instead of erroring out. Style and color palettes ~~~~~~~~~~~~~~~~~~~~~~~~ - Added the :func:`cubehelix_palette` function for generating sequential palettes from the cubehelix system. See the :ref:`palette docs ` for more information on how these palettes can be used.
There is also the :func:`choose_cubehelix` which will launch an interactive app to select cubehelix parameters in the notebook. - Added the :func:`xkcd_palette` and the ``xkcd_rgb`` dictionary so that colors can be specified with names from the `xkcd color survey `_. - Added the ``font_scale`` option to :func:`plotting_context`, :func:`set_context`, and :func:`set`. ``font_scale`` can independently increase or decrease the size of the font elements in the plot. - Font-handling should work better on systems without Arial installed. This is accomplished by adding the ``font.sans-serif`` field to the ``axes_style`` definition with Arial and Liberation Sans prepended to matplotlib defaults. The font family can also be set through the ``font`` keyword argument in :func:`set`. Due to matplotlib bugs, this might not work as expected on matplotlib 1.3. - The :func:`despine` function gets a new keyword argument ``offset``, which replaces the deprecated :func:`offset_spines` function. You no longer need to offset the spines before plotting data. - Added a default value for ``pdf.fonttype`` so that text in PDFs is editable in Adobe Illustrator. Other API Changes ~~~~~~~~~~~~~~~~~ - Removed the deprecated ``set_color_palette`` and ``palette_context`` functions. These were replaced in version 0.3 by the :func:`set_palette` function and ability to use :func:`color_palette` directly in a ``with`` statement. - Removed the ability to specify a ``nogrid`` style, which was renamed to ``white`` in 0.3. ================================================ FILE: doc/whatsnew/v0.5.0.rst ================================================ v0.5.0 (November 2014) -------------------------- This is a major release from 0.4. Highlights include new functions for plotting heatmaps, possibly while applying clustering algorithms to discover structured relationships. 
These functions are complemented by new custom colormap functions and a full set of IPython widgets that allow interactive selection of colormap parameters. The palette tutorial has been rewritten to cover these new tools and more generally provide guidance on how to use color in visualizations. There are also a number of smaller changes and bugfixes. Plotting functions ~~~~~~~~~~~~~~~~~~ - Added the :func:`heatmap` function for visualizing a matrix of data by color-encoding the values. See the docs for more information. - Added the :func:`clustermap` function for clustering and visualizing a matrix of data, with options to label individual rows and columns by colors. See the docs for more information. This work was led by Olga Botvinnik. - :func:`lmplot` and :func:`pairplot` get a new keyword argument, ``markers``. This can be a single kind of marker or a list of different markers for each level of the ``hue`` variable. Using different markers for different hues should let plots be more comprehensible when reproduced in black-and-white (i.e. when printed). See the `github pull request (#323) `_ for examples. - More generally, there is a new keyword argument in :class:`FacetGrid` and :class:`PairGrid`, ``hue_kws``. This similarly lets plot aesthetics vary across the levels of the hue variable, but more flexibly. ``hue_kws`` should be a dictionary that maps the name of keyword arguments to lists of values that are as long as the number of levels of the hue variable. - The argument ``subplot_kws`` has been added to ``FacetGrid``. This allows for faceted plots with custom projections, including `maps with Cartopy `_. Color palettes ~~~~~~~~~~~~~~ - Added two new functions to create custom color palettes. For sequential palettes, you can use the :func:`light_palette` function, which takes a seed color and creates a ramp from a very light, desaturated variant of it.
For diverging palettes, you can use the :func:`diverging_palette` function to create a balanced ramp between two endpoints to a light or dark midpoint. See the :ref:`palette tutorial ` for more information. - Added the ability to specify the seed color for :func:`light_palette` and :func:`dark_palette` as a tuple of ``husl`` or ``hls`` space values or as a named ``xkcd`` color. The interpretation of the seed color is now provided by the new ``input`` parameter to these functions. - Added several new interactive palette widgets: :func:`choose_colorbrewer_palette`, :func:`choose_light_palette`, :func:`choose_dark_palette`, and :func:`choose_diverging_palette`. For consistency, renamed the cubehelix widget to :func:`choose_cubehelix_palette` (and fixed a bug where the cubehelix palette was reversed). These functions also now return either a color palette list or a matplotlib colormap when called, and that object will be live-updated as you play with the widget. This should make it easy to iterate over a plot until you find a good representation for the data. See the `Github pull request `_ or `this notebook (download it to use the widgets) `_ for more information. - Overhauled the color :ref:`palette tutorial ` to organize the discussion by class of color palette and provide more motivation behind the various choices one might make when choosing colors for their data. Bug fixes ~~~~~~~~~ - Fixed a bug in :class:`PairGrid` that gave incorrect results (or a crash) when the input DataFrame has a non-default index. - Fixed a bug in :class:`PairGrid` where passing columns with a date-like datatype raised an exception. - Fixed a bug where :func:`lmplot` would show a legend when the hue variable was also used on either the rows or columns (making the legend redundant). - Worked around a matplotlib bug that was forcing outliers in :func:`boxplot` to appear as blue. - :func:`kdeplot` now accepts pandas Series for the ``data`` and ``data2`` arguments. 
- Using a non-default correlation method in :func:`corrplot` now implies ``sig_stars=False`` as the permutation test used to compute significance values for the correlations uses a Pearson metric. - Removed ``pdf.fonttype`` from the style definitions, as the value used in version 0.4 resulted in very large PDF files. ================================================ FILE: doc/whatsnew/v0.5.1.rst ================================================ v0.5.1 (November 2014) ---------------------- This is a bugfix release that includes a workaround for an issue in matplotlib 1.4.2 and fixes for two bugs in functions that were new in 0.5.0. - Implemented a workaround for a bug in matplotlib 1.4.2 that prevented point markers from being drawn when the seaborn styles had been set. See this `github issue `_ for more information. - Fixed a bug in :func:`heatmap` where the mask was vertically reversed relative to the data. - Fixed a bug in :func:`clustermap` when using nested lists of side colors. ================================================ FILE: doc/whatsnew/v0.6.0.rst ================================================ v0.6.0 (June 2015) ------------------ This is a major release from 0.5. The main objective of this release was to unify the API for categorical plots, which means that there are some relatively large API changes in some of the older functions. See below for details of those changes, which may break code written for older versions of seaborn. There are also some new functions (:func:`stripplot`, and :func:`countplot`), numerous enhancements to existing functions, and bug fixes. Additionally, the documentation has been completely revamped and expanded for the 0.6 release. Now, the API docs page for each function has multiple examples with embedded plots showing how to use the various options.
These pages should be considered the most comprehensive resource for examples, and the tutorial pages are now streamlined and oriented towards a higher-level overview of the various features. Changes and updates to categorical plots ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ In version 0.6, the "categorical" plots have been unified with a common API. This new category of functions groups together plots that show the relationship between one numeric variable and one or two categorical variables. This includes plots that show distribution of the numeric variable in each bin (:func:`boxplot`, :func:`violinplot`, and :func:`stripplot`) and plots that apply a statistical estimation within each bin (:func:`pointplot`, :func:`barplot`, and :func:`countplot`). There is a new :ref:`tutorial chapter ` that introduces these functions. The categorical functions now each accept the same formats of input data and can be invoked in the same way. They can plot using long- or wide-form data, and can be drawn vertically or horizontally. When long-form data is used, the orientation of the plots is inferred from the types of the input data. Additionally, all functions natively take a ``hue`` variable to add a second layer of categorization. With the (in some cases new) API, these functions can all be drawn correctly by :class:`FacetGrid`. However, ``factorplot`` can also now create faceted versions of any of these kinds of plots, so in most cases it will be unnecessary to use :class:`FacetGrid` directly. By default, ``factorplot`` draws a point plot, but this is controlled by the ``kind`` parameter. Here are details on what has changed in the process of unifying these APIs: - Changes to :func:`boxplot` and :func:`violinplot` will probably be the most disruptive. Both functions maintain backwards-compatibility in terms of the kind of data they can accept, but the syntax has changed to be more similar to other seaborn functions. 
These functions are now invoked with ``x`` and/or ``y`` parameters that are either vectors of data or names of variables in a long-form DataFrame passed to the new ``data`` parameter. You can still pass wide-form DataFrames or arrays to ``data``, but it is no longer the first positional argument. See the `github pull request (#410) `_ for more information on these changes and the logic behind them. - As :func:`pointplot` and :func:`barplot` can now plot with the major categorical variable on the y axis, the ``x_order`` parameter has been renamed to ``order``. - Added a ``hue`` argument to :func:`boxplot` and :func:`violinplot`, which allows for nested grouping of the plot elements by a third categorical variable. For :func:`violinplot`, this nesting can also be accomplished by splitting the violins when there are two levels of the ``hue`` variable (using ``split=True``). To make this functionality feasible, the ability to specify where the plots will be drawn in data coordinates has been removed. These plots now are drawn at set positions, like (and identical to) :func:`barplot` and :func:`pointplot`. - Added a ``palette`` parameter to :func:`boxplot`/:func:`violinplot`. The ``color`` parameter still exists, but no longer does double-duty in accepting the name of a seaborn palette. ``palette`` supersedes ``color`` so that it can be used with a :class:`FacetGrid`. Along with these API changes, the following changes/enhancements were made to the plotting functions: - The default rules for ordering the categories have changed. Instead of automatically sorting the category levels, the plots now show the levels in the order they appear in the input data (i.e., the order given by ``Series.unique()``). Order can be specified when plotting with the ``order`` and ``hue_order`` parameters. Additionally, when variables are pandas objects with a "categorical" dtype, the category order is inferred from the data object.
This change also affects :class:`FacetGrid` and :class:`PairGrid`. - Added the ``scale`` and ``scale_hue`` parameters to :func:`violinplot`. These control how the width of the violins is scaled. The default is ``area``, which is different from how the violins used to be drawn. Use ``scale='width'`` to get the old behavior. - Used a different style for the ``box`` kind of interior plot in :func:`violinplot`, which shows the whisker range in addition to the quartiles. Use ``inner='quartile'`` to get the old style. New plotting functions ~~~~~~~~~~~~~~~~~~~~~~ - Added the :func:`stripplot` function, which draws a scatterplot where one of the variables is categorical. This plot has the same API as :func:`boxplot` and :func:`violinplot`. It is useful both on its own and when composed with one of these other plot kinds to show both the observations and underlying distribution. - Added the :func:`countplot` function, which uses a bar plot representation to show counts of variables in one or more categorical bins. This replaces the old approach of calling :func:`barplot` without a numeric variable. Other additions and changes ~~~~~~~~~~~~~~~~~~~~~~~~~~~ - The :func:`corrplot` and underlying :func:`symmatplot` functions have been deprecated in favor of :func:`heatmap`, which is much more flexible and robust. These two functions are still available in version 0.6, but they will be removed in a future version. - Added the :func:`set_color_codes` function and the ``color_codes`` argument to :func:`set` and :func:`set_palette`. This changes the interpretation of shorthand color codes (i.e. "b", "g", "k", etc.) within matplotlib to use the values from one of the named seaborn palettes (i.e. "deep", "muted", etc.). That makes it easier to have a more uniform look when using matplotlib functions directly with seaborn imported. This could be disruptive to existing plots, so it does not happen by default. It is possible this could change in the future.
- The :func:`color_palette` function no longer trims palettes that are longer than 6 colors when passed into it. - Added the ``as_hex`` method to color palette objects, to return a list of hex codes rather than rgb tuples. - :func:`jointplot` now passes additional keyword arguments to the function used to draw the plot on the joint axes. - Changed the default ``linewidths`` in :func:`heatmap` and :func:`clustermap` to 0 so that larger matrices plot correctly. This parameter still exists and can be used to get the old effect of lines demarcating each cell in the heatmap (the old default ``linewidths`` was 0.5). - :func:`heatmap` and :func:`clustermap` now automatically use a mask for missing values, which previously were shown with the "under" value of the colormap per default `plt.pcolormesh` behavior. - Added the ``seaborn.crayons`` dictionary and the :func:`crayon_palette` function to define colors from the 120 box (!) of `Crayola crayons `_. - Added the ``line_kws`` parameter to :func:`residplot` to change the style of the lowess line, when used. - Added open-ended ``**kwargs`` to the ``add_legend`` method on :class:`FacetGrid` and :class:`PairGrid`, which will pass additional keyword arguments through when calling the legend function on the ``Figure`` or ``Axes``. - Added the ``gridspec_kws`` parameter to :class:`FacetGrid`, which allows for control over the size of individual facets in the grid to emphasize certain plots or account for differences in variable ranges. - The interactive palette widgets now show a continuous colorbar, rather than a discrete palette, when `as_cmap` is True. - The default Axes size for :func:`pairplot` and :class:`PairGrid` is now slightly smaller. - Added the ``shade_lowest`` parameter to :func:`kdeplot` which will set the alpha for the lowest contour level to 0, making it easier to plot multiple bivariate distributions on the same axes. 
- The ``height`` parameter of :func:`rugplot` is now interpreted as a function of the axis size and is invariant to changes in the data scale on that axis. The rug lines are also slightly narrower by default. - Added a catch in :func:`distplot` when calculating a default number of bins. For highly skewed data it will now use sqrt(n) bins, where previously the reference rule would return "infinite" bins and cause an exception in matplotlib. - Added a ceiling (50) to the default number of bins used for :func:`distplot` histograms. This will help avoid confusing errors with certain kinds of datasets that heavily violate the assumptions of the reference rule used to get a default number of bins. The ceiling is not applied when passing a specific number of bins. - The various property dictionaries that can be passed to ``plt.boxplot`` are now applied after the seaborn restyling to allow for full customizability. - Added a ``savefig`` method to :class:`JointGrid` that defaults to a tight bounding box to make it easier to save figures using this class, and set a tight bbox as the default for the ``savefig`` method on other Grid objects. - You can now pass an integer to the ``xticklabels`` and ``yticklabels`` parameter of :func:`heatmap` (and, by extension, :func:`clustermap`). This will make the plot use the ticklabels inferred from the data, but only plot every ``n`` label, where ``n`` is the number you pass. This can help when visualizing larger matrices with some sensible ordering to the rows or columns of the dataframe. - Added `"figure.facecolor"` to the style parameters and set the default to white. - The :func:`load_dataset` function now caches datasets locally after downloading them, and uses the local copy on subsequent calls. Bug fixes ~~~~~~~~~ - Fixed bugs in :func:`clustermap` where the mask and specified ticklabels were not being reorganized using the dendrograms. 
- Fixed a bug in :class:`FacetGrid` and :class:`PairGrid` that led to incorrect legend labels when levels of the ``hue`` variable appeared in ``hue_order`` but not in the data. - Fixed a bug in :meth:`FacetGrid.set_xticklabels` or :meth:`FacetGrid.set_yticklabels` when ``col_wrap`` is being used. - Fixed a bug in :class:`PairGrid` where the ``hue_order`` parameter was ignored. - Fixed two bugs in :func:`despine` that caused errors when trying to trim the spines on plots that had inverted axes or no ticks. - Improved support for the ``margin_titles`` option in :class:`FacetGrid`, which can now be used with a legend. ================================================ FILE: doc/whatsnew/v0.7.0.rst ================================================ v0.7.0 (January 2016) --------------------- This is a major release from 0.6. The main new feature is :func:`swarmplot` which implements the beeswarm approach for drawing categorical scatterplots. There are also some performance improvements, bug fixes, and updates for compatibility with new versions of dependencies. - Added the :func:`swarmplot` function, which draws beeswarm plots. These are categorical scatterplots, similar to those produced by :func:`stripplot`, but position of the points on the categorical axis is chosen to avoid overlapping points. See the :ref:`categorical plot tutorial ` for more information. - Changed some of the :func:`stripplot` defaults to be closer to :func:`swarmplot`. Points are now somewhat smaller, have no outlines, and are not split by default when using ``hue``. These settings remain customizable through function parameters. - Added an additional rule when determining category order in categorical plots. Now, when numeric variables are used in a categorical role, the default behavior is to sort the unique levels of the variable (i.e. they will be in proper numerical order).
This can still be overridden by the appropriate ``{*_}order`` parameter, and variables with a ``category`` datatype will still follow the category order even if the levels are strictly numerical. - Changed how :func:`stripplot` draws points when using ``hue`` nesting with ``split=False`` so that the different ``hue`` levels are not drawn strictly on top of each other. - Improve performance for large dendrograms in :func:`clustermap`. - Added ``font.size`` to the plotting context definition so that the default output from ``plt.text`` will be scaled appropriately. - Fixed a bug in :func:`clustermap` when ``fastcluster`` is not installed. - Fixed a bug in the zscore calculation in :func:`clustermap`. - Fixed a bug in :func:`distplot` where sometimes the default number of bins would not be an integer. - Fixed a bug in :func:`stripplot` where a legend item would not appear for a ``hue`` level if there were no observations in the first group of points. - Heatmap colorbars are now rasterized for better performance in vector plots. - Added workarounds for some matplotlib boxplot issues, such as strange colors of outlier points. - Added workarounds for an issue where violinplot edges would be missing or have random colors. - Added a workaround for an issue where only one :func:`heatmap` cell would be annotated on some matplotlib backends. - Fixed a bug on newer versions of matplotlib where a colormap would be erroneously applied to scatterplots with only three observations. - Updated seaborn for compatibility with matplotlib 1.5. - Added compatibility for various IPython (and Jupyter) versions in functions that use widgets. ================================================ FILE: doc/whatsnew/v0.7.1.rst ================================================ v0.7.1 (June 2016) ------------------- - Added the ability to put "caps" on the error bars that are drawn by :func:`barplot` or :func:`pointplot` (and, by extension, ``factorplot``). 
Additionally, the line width of the error bars can now be controlled. These changes involve the new parameters ``capsize`` and ``errwidth``. See the `github pull request (#898) `_ for examples of usage. - Improved the row and column colors display in :func:`clustermap`. It is now possible to pass Pandas objects for these elements and, when possible, the semantic information in the Pandas objects will be used to add labels to the plot. When Pandas objects are used, the color data is matched against the main heatmap based on the index, not on position. This is more accurate, but it may lead to different results if current code assumed positional matching. - Improved the luminance calculation that determines the annotation color in :func:`heatmap`. - The ``annot`` parameter of :func:`heatmap` now accepts a rectangular dataset in addition to a boolean value. If a dataset is passed, its values will be used for the annotations, while the main dataset will be used for the heatmap cell colors. - Fixed a bug in :class:`FacetGrid` that appeared when using ``col_wrap`` with missing ``col`` levels. - Made it possible to pass a tick locator object to the :func:`heatmap` colorbar. - Made it possible to use different styles (e.g., step) for :class:`PairGrid` histograms when there are multiple hue levels. - Fixed a bug in scipy-based univariate kernel density bandwidth calculation. - The :func:`reset_orig` function (and, by extension, importing ``seaborn.apionly``) resets matplotlib rcParams to their values at the time seaborn itself was imported, which should work better with rcParams changed by the jupyter notebook backend. - Removed some objects from the top-level ``seaborn`` namespace. - Improved unicode compatibility in :class:`FacetGrid`. ================================================ FILE: doc/whatsnew/v0.8.0.rst ================================================ v0.8.0 (July 2017) ------------------ - The default style is no longer applied when seaborn is imported. 
It is now necessary to explicitly call :func:`set` or one or more of :func:`set_style`, :func:`set_context`, and :func:`set_palette`. Correspondingly, the ``seaborn.apionly`` module has been deprecated. - Changed the behavior of :func:`heatmap` (and by extension :func:`clustermap`) when plotting divergent datasets (i.e. when the ``center`` parameter is used). Instead of extending the lower and upper limits of the colormap to be symmetrical around the ``center`` value, the colormap is modified so that its middle color corresponds to ``center``. This means that the full range of the colormap will not be used (unless the data or specified ``vmin`` and ``vmax`` are symmetric), but the upper and lower limits of the colorbar will correspond to the range of the data. See the Github pull request `(#1184) `_ for examples of the behavior. - Removed automatic detection of diverging data in :func:`heatmap` (and by extension :func:`clustermap`). If you want the colormap to be treated as diverging (see above), it is now necessary to specify the ``center`` value. When no colormap is specified, specifying ``center`` will still change the default to be one that is more appropriate for displaying diverging data. - Added four new colormaps, created using `viscm `_ for perceptual uniformity. The new colormaps include two sequential colormaps ("rocket" and "mako") and two diverging colormaps ("icefire" and "vlag"). These colormaps are registered with matplotlib on seaborn import and the colormap objects can be accessed in the ``seaborn.cm`` namespace. - Changed the default :func:`heatmap` colormaps to be "rocket" (in the case of sequential data) or "icefire" (in the case of diverging data). Note that this change reverses the direction of the luminance ramp from the previous defaults.
While potentially confusing and disruptive, this change better aligns the seaborn defaults with the new matplotlib default colormap ("viridis") and arguably better aligns the semantics of a "heat" map with the appearance of the colormap. - Added ``"auto"`` as a (default) option for tick labels in :func:`heatmap` and :func:`clustermap`. This will try to estimate how many ticks can be labeled without the text objects overlapping, which should improve performance for larger matrices. - Added the ``dodge`` parameter to :func:`boxplot`, :func:`violinplot`, and :func:`barplot` to allow use of ``hue`` without changing the position or width of the plot elements, as when the ``hue`` variable is not nested within the main categorical variable. - Correspondingly, the ``split`` parameter for :func:`stripplot` and :func:`swarmplot` has been renamed to ``dodge`` for consistency with the other categorical functions (and for differentiation from the meaning of ``split`` in :func:`violinplot`). - Added the ability to draw a colorbar for a bivariate :func:`kdeplot` with the ``cbar`` parameter (and related ``cbar_ax`` and ``cbar_kws`` parameters). - Added the ability to use error bars to show standard deviations rather than bootstrap confidence intervals in most statistical functions by putting ``ci="sd"``. - Allow side-specific offsets in :func:`despine`. - Figure size is no longer part of the seaborn plotting context parameters. - Put a cap on the number of bins used in :func:`jointplot` for ``kind="hex"`` to avoid hanging when the reference rule prescribes too many. - Changed the y axis in :func:`heatmap`. Instead of reversing the rows of the data internally, the y axis is now inverted. This may affect code that draws on top of the heatmap in data coordinates. - Turn off dendrogram axes in :func:`clustermap` rather than setting the background color to white. - New matplotlib qualitative palettes (e.g. "tab10") are now handled correctly.
- Some modules and functions have been internally reorganized; there should be no effect on code that uses the ``seaborn`` namespace. - Added a deprecation warning to ``tsplot`` function to indicate that it will be removed or replaced with a substantially altered version in a future release. - The ``interactplot`` and ``coefplot`` functions are officially deprecated and will be removed in a future release. ================================================ FILE: doc/whatsnew/v0.8.1.rst ================================================ v0.8.1 (September 2017) ----------------------- - Added a warning in :class:`FacetGrid` when passing a categorical plot function without specifying ``order`` (or ``hue_order`` when ``hue`` is used), which is likely to produce a plot that is incorrect. - Improved compatibility between :class:`FacetGrid` or :class:`PairGrid` and interactive matplotlib backends so that the legend no longer remains inside the figure when using ``legend_out=True``. - Changed categorical plot functions with small plot elements to use :func:`dark_palette` instead of :func:`light_palette` when generating a sequential palette from a specified color. - Improved robustness of :func:`kdeplot` and :func:`distplot` to data with fewer than two observations. - Fixed a bug in :func:`clustermap` when using ``yticklabels=False``. - Fixed a bug in :func:`pointplot` where colors were wrong if exactly three points were being drawn. - Fixed a bug in :func:`pointplot` where legend entries for missing data appeared with empty markers. - Fixed a bug in :func:`clustermap` where an error was raised when annotating the main heatmap and showing category colors. - Fixed a bug in :func:`clustermap` where row labels were not being properly rotated when they overlapped. - Fixed a bug in :func:`kdeplot` where the maximum limit on the density axes was not being updated when multiple densities were drawn. - Improved compatibility with future versions of pandas. 
================================================ FILE: doc/whatsnew/v0.9.0.rst ================================================ v0.9.0 (July 2018) ------------------ This is a major release with several substantial and long-desired new features. There are also updates/modifications to the themes and color palettes that give better consistency with matplotlib 2.0 and some notable API changes. New relational plots ~~~~~~~~~~~~~~~~~~~~ Three completely new plotting functions have been added: :func:`relplot`, :func:`scatterplot`, and :func:`lineplot`. The first is a figure-level interface to the latter two that combines them with a :class:`FacetGrid`. The functions bring the high-level, dataset-oriented API of the seaborn categorical plotting functions to more general plots (scatter plots and line plots). These functions can visualize a relationship between two numeric variables while mapping up to three additional variables by modifying ``hue``, ``size``, and/or ``style`` semantics. The common high-level API is implemented differently in the two functions. For example, the size semantic in :func:`scatterplot` scales the area of scatter plot points, but in :func:`lineplot` it scales width of the line plot lines. The API is dataset-oriented, meaning that in both cases you pass the variable in your dataset rather than directly specifying the matplotlib parameters to use for point area or line width. Another way the relational functions differ from existing seaborn functionality is that they have better support for using numeric variables for ``hue`` and ``size`` semantics. This functionality may be propagated to other functions that can add a ``hue`` semantic in future versions; it has not been in this release. The :func:`lineplot` function also has support for statistical estimation and is replacing the older ``tsplot`` function, which still exists but is marked for removal in a future release. 
:func:`lineplot` is better aligned with the API of the rest of the library and more flexible in showing relationships across additional variables by modifying the size and style semantics independently. It also has substantially improved support for date and time data, a major pain factor in ``tsplot``. The cost is that some of the more esoteric options in ``tsplot`` for representing uncertainty (e.g. a colormapped KDE of the bootstrap distribution) have not been implemented in the new function. There is quite a bit of new documentation that explains these new functions in more detail, including detailed examples of the various options in the :ref:`API reference <relational_api>` and a more verbose :ref:`tutorial <relational_tutorial>`. These functions should be considered in a "stable beta" state. They have been thoroughly tested, but some unknown corner cases may remain to be found. The main features are in place, but not all planned functionality has been implemented. There are planned improvements to some elements, particularly the default legend, that are a little rough around the edges in this release. Finally, some of the default behavior (e.g. the default range of point/line sizes) may change somewhat in future releases. Updates to themes and palettes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Several changes have been made to the seaborn style themes, context scaling, and color palettes. In general the aim of these changes was to make the seaborn styles more consistent with the `style updates in matplotlib 2.0 <https://matplotlib.org/users/dflt_style_changes.html>`_ and to leverage some of the new style parameters for better implementation of some aspects of the seaborn styles. Here is a list of the changes: - Reorganized and updated some :func:`axes_style`/:func:`plotting_context` parameters to take advantage of improvements in the matplotlib 2.0 update. The biggest change involves using several new parameters in the "style" spec while moving parameters that used to implement the corresponding aesthetics to the "context" spec.
For example, axes spines and ticks are now off instead of having their width/length zeroed out for the darkgrid style. That means the width/length of these elements can now be scaled in different contexts. The effect is a more cohesive appearance of the plots, especially in larger contexts. These changes include only minimal support for the 1.x matplotlib series. Users who are stuck on matplotlib 1.5 but wish to use seaborn styling may want to use the seaborn parameters that can be accessed through the `matplotlib stylesheet interface <https://matplotlib.org/users/style_sheets.html>`_. - Updated the seaborn palettes ("deep", "muted", "colorblind", etc.) to correspond with the new 10-color matplotlib default. The legacy palettes are now available at "deep6", "muted6", "colorblind6", etc. Additionally, a few individual colors were tweaked for better consistency, aesthetics, and accessibility. - Calling :func:`color_palette` (or :func:`set_palette`) with a named qualitative palette (i.e. one of the seaborn palettes, the colorbrewer qualitative palettes, or the matplotlib tableau-derived palettes) and no specified number of colors will return all of the colors in the palette. This means that for some palettes, the returned list will have a different length than it did in previous versions. - Enhanced :func:`color_palette` to accept a parameterized specification of a cubehelix palette in a string, prefixed with ``"ch:"`` (e.g. ``"ch:-.1,.2,l=.7"``). Note that keyword arguments can be spelled out or referenced using only their first letter. Reversing the palette is accomplished by appending ``"_r"``, as with other matplotlib colormaps. This specification will be accepted by any seaborn function with a ``palette=`` parameter. - Slightly increased the base font sizes in :func:`plotting_context` and increased the scaling factors for ``"talk"`` and ``"poster"`` contexts.
- Calling :func:`set` will now call :func:`set_color_codes` to re-assign the single letter color codes by default API changes ~~~~~~~~~~~ A few functions have been renamed or have had changes to their default parameters. - The ``factorplot`` function has been renamed to :func:`catplot`. The new name ditches the original R-inflected terminology to use a name that is more consistent with terminology in pandas and in seaborn itself. This change should hopefully make :func:`catplot` easier to discover, and it should make more clear what its role is. ``factorplot`` still exists and will pass its arguments through to :func:`catplot` with a warning. It may be removed eventually, but the transition will be as gradual as possible. - The other reason that the ``factorplot`` name was changed was to ease another alteration which is that the default ``kind`` in :func:`catplot` is now ``"strip"`` (corresponding to :func:`stripplot`). This plots a categorical scatter plot which is usually a much better place to start and is more consistent with the default in :func:`relplot`. The old default style in ``factorplot`` (``"point"``, corresponding to :func:`pointplot`) remains available if you want to show a statistical estimation. - The ``lvplot`` function has been renamed to :func:`boxenplot`. The "letter-value" terminology that was used to name the original kind of plot is obscure, and the abbreviation to ``lv`` did not help anything. The new name should make the plot more discoverable by describing its format (it plots multiple boxes, also known as "boxen"). As with ``factorplot``, the ``lvplot`` function still exists to provide a relatively smooth transition. 
- Renamed the ``size`` parameter to ``height`` in multi-plot grid objects (:class:`FacetGrid`, :class:`PairGrid`, and :class:`JointGrid`) along with functions that use them (``factorplot``, :func:`lmplot`, :func:`pairplot`, and :func:`jointplot`) to avoid conflicts with the ``size`` parameter that is used in ``scatterplot`` and ``lineplot`` (necessary to make :func:`relplot` work) and to make the meaning of the parameter a bit more clear. - Changed the default diagonal plots in :func:`pairplot` to use :func:`kdeplot` when a ``"hue"`` dimension is used. - Deprecated the statistical annotation component of :class:`JointGrid`. The method is still available but will be removed in a future version. - Two older functions that were deprecated in earlier versions, ``coefplot`` and ``interactplot``, have undergone final removal from the code base. Documentation improvements ~~~~~~~~~~~~~~~~~~~~~~~~~~ There has been some effort put into improving the documentation. The biggest change is that the :ref:`introduction to the library <introduction>` has been completely rewritten to provide much more information and, critically, examples. In addition to the high-level motivation, the introduction also covers some important topics that are often sources of confusion, like the distinction between figure-level and axes-level functions, how datasets should be formatted for use in seaborn, and how to customize the appearance of the plots. Other improvements have been made throughout, most notably a thorough re-write of the :ref:`categorical tutorial <categorical_tutorial>`. Other small enhancements and bug fixes ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Changed :func:`rugplot` to plot a matplotlib ``LineCollection`` instead of many ``Line2D`` objects, providing a big speedup for large arrays. - Changed the default off-diagonal plots to use :func:`scatterplot`. (Note that the ``"hue"`` currently draws three separate scatterplots instead of using the hue semantic of the scatterplot function).
- Changed color handling when using :func:`kdeplot` with two variables. The default colormap for the 2D density now follows the color cycle, and the function can use ``color`` and ``label`` kwargs, adding more flexibility and avoiding a warning when using with multi-plot grids. - Added the ``subplot_kws`` parameter to :class:`PairGrid` for more flexibility. - Removed a special case in :class:`PairGrid` that defaulted to drawing stacked histograms on the diagonal axes. - Fixed :func:`jointplot`/:class:`JointGrid` and :func:`regplot` so that they now accept list inputs. - Fixed a bug in :class:`FacetGrid` when using a single row/column level or using ``col_wrap=1``. - Fixed functions that set axis limits so that they preserve auto-scaling state on matplotlib 2.0. - Avoided an error when using matplotlib backends that cannot render a canvas (e.g. PDF). - Changed the install infrastructure to explicitly declare dependencies in a way that ``pip`` is aware of. This means that ``pip install seaborn`` will now work in an empty environment. Additionally, the dependencies are specified with strict minimal versions. - Updated the testing infrastructure to execute tests with `pytest `_ (although many individual tests still use nose assertion). ================================================ FILE: doc/whatsnew/v0.9.1.rst ================================================ v0.9.1 (January 2020) --------------------- This is a minor release with a number of bug fixes and adaptations to changes in seaborn's dependencies. There are also several new features. This is the final version of seaborn that will support Python 2.7 or 3.5. New features ~~~~~~~~~~~~ - Added more control over the arrangement of the elements drawn by :func:`clustermap` with the ``{dendrogram,colors}_ratio`` and ``cbar_pos`` parameters. Additionally, the default organization and scaling with different figure sizes has been improved. 
- Added the ``corner`` option to :class:`PairGrid` and :func:`pairplot` to make a grid without the upper triangle of bivariate axes. - Added the ability to seed the random number generator for the bootstrap used to define error bars in several plots. Relevant functions now have a ``seed`` parameter, which can take either a fixed seed (typically an ``int``) or a numpy random number generator object (either the newer :class:`numpy.random.Generator` or the older :class:`numpy.random.mtrand.RandomState`). - Generalized the idea of "diagonal" axes in :class:`PairGrid` to any axes that share an x and y variable. - In :class:`PairGrid`, the ``hue`` variable is now excluded from the default list of variables that make up the rows and columns of the grid. - Exposed the ``layout_pad`` parameter in :class:`PairGrid` and set a smaller default than what matplotlib sets for more efficient use of space in dense grids. - It is now possible to force a categorical interpretation of the ``hue`` variable in a relational plot by passing the name of a categorical palette (e.g. ``"deep"`` or ``"Set2"``). This complements the (previously supported) option of passing a list/dict of colors. - Added the ``tree_kws`` parameter to :func:`clustermap` to control the properties of the lines in the dendrogram. - Added the ability to pass hierarchical label names to the :class:`FacetGrid` legend, which also fixes a bug in :func:`relplot` when the same label appeared in different semantics. - Improved support for grouping observations based on pandas index information in categorical plots. Bug fixes and adaptations ~~~~~~~~~~~~~~~~~~~~~~~~~ - Avoided an error when singular data is passed to :func:`kdeplot`, issuing a warning instead. This makes :func:`pairplot` more robust. - Fixed the behavior of ``dropna`` in :class:`PairGrid` to properly exclude null datapoints from each plot when set to ``True``.
- Fixed an issue where :func:`regplot` could interfere with other axes in a multi-plot matplotlib figure. - Semantic variables with a ``category`` data type will always be treated as categorical in relational plots. - Avoided a warning about color specifications that arose from :func:`boxenplot` on newer matplotlibs. - Adapted to a change in how matplotlib scales axis margins, which caused multiple calls to :func:`regplot` with ``truncate=False`` to progressively expand the x axis limits. Because there are currently limitations on how autoscaling works in matplotlib, the default value for ``truncate`` in seaborn has also been changed to ``True``. - Relational plots no longer error when hue/size data are inferred to be numeric but stored with a string datatype. - Relational plots now consider semantics with only a single value that can be interpreted as boolean (0 or 1) to be categorical, not numeric. - Relational plots now handle list or dict specifications for ``sizes`` correctly. - Fixed an issue in :func:`pointplot` where missing levels of a hue variable would cause an exception after a recent update in matplotlib. - Fixed a bug when setting the rotation of x tick labels on a :class:`FacetGrid`. - Fixed a bug where values would be excluded from categorical plots when only one variable was a pandas ``Series`` with a non-default index. - Fixed a bug when using ``Series`` objects as arguments for ``x_partial`` or ``y_partial`` in :func:`regplot`. - Fixed a bug when passing a ``norm`` object and using color annotations in :func:`clustermap`. - Fixed a bug where annotations were not rearranged to match the clustering in :func:`clustermap`. - Fixed a bug when trying to call :func:`set` while specifying a list of colors for the palette. - Fixed a bug when resetting the color code short-hands to the matplotlib default. - Avoided errors from stricter type checking in upcoming ``numpy`` changes. 
- Avoided error/warning in :func:`lineplot` when plotting categoricals with empty levels. - Allowed ``colors`` to be passed through to a bivariate :func:`kdeplot`. - Standardized the output format of custom color palette functions. - Fixed a bug where legends for numerical variables in a relational plot could show a surprisingly large number of decimal places. - Improved robustness to missing values in distribution plots. - Made it possible to specify the location of the :class:`FacetGrid` legend using matplotlib keyword arguments. ================================================ FILE: examples/.gitignore ================================================ *.html *_files/ ================================================ FILE: examples/anscombes_quartet.py ================================================ """ Anscombe's quartet ================== _thumb: .4, .4 """ import seaborn as sns sns.set_theme(style="ticks") # Load the example dataset for Anscombe's quartet df = sns.load_dataset("anscombe") # Show the results of a linear regression within each dataset sns.lmplot( data=df, x="x", y="y", col="dataset", hue="dataset", col_wrap=2, palette="muted", ci=None, height=4, scatter_kws={"s": 50, "alpha": 1} ) ================================================ FILE: examples/different_scatter_variables.py ================================================ """ Scatterplot with multiple semantics =================================== _thumb: .45, .5 """ import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="whitegrid") # Load the example diamonds dataset diamonds = sns.load_dataset("diamonds") # Draw a scatter plot while assigning point colors and sizes to different # variables in the dataset f, ax = plt.subplots(figsize=(6.5, 6.5)) sns.despine(f, left=True, bottom=True) clarity_ranking = ["I1", "SI2", "SI1", "VS2", "VS1", "VVS2", "VVS1", "IF"] sns.scatterplot(x="carat", y="price", hue="clarity", size="depth", palette="ch:r=-.2,d=.3_r", hue_order=clarity_ranking, 
sizes=(1, 8), linewidth=0, data=diamonds, ax=ax) ================================================ FILE: examples/errorband_lineplots.py ================================================ """ Timeseries plot with error bands ================================ _thumb: .48, .45 """ import seaborn as sns sns.set_theme(style="darkgrid") # Load an example dataset with long-form data fmri = sns.load_dataset("fmri") # Plot the responses for different events and regions sns.lineplot(x="timepoint", y="signal", hue="region", style="event", data=fmri) ================================================ FILE: examples/faceted_histogram.py ================================================ """ Facetting histograms by subsets of data ======================================= _thumb: .33, .57 """ import seaborn as sns sns.set_theme(style="darkgrid") df = sns.load_dataset("penguins") sns.displot( df, x="flipper_length_mm", col="species", row="sex", binwidth=3, height=3, facet_kws=dict(margin_titles=True), ) ================================================ FILE: examples/faceted_lineplot.py ================================================ """ Line plots on multiple facets ============================= _thumb: .48, .42 """ import seaborn as sns sns.set_theme(style="ticks") dots = sns.load_dataset("dots") # Define the palette as a list to specify exact values palette = sns.color_palette("rocket_r") # Plot the lines on two facets sns.relplot( data=dots, x="time", y="firing_rate", hue="coherence", size="choice", col="align", kind="line", size_order=["T1", "T2"], palette=palette, height=5, aspect=.75, facet_kws=dict(sharex=False), ) ================================================ FILE: examples/grouped_barplot.py ================================================ """ Grouped barplots ================ _thumb: .36, .5 """ import seaborn as sns sns.set_theme(style="whitegrid") penguins = sns.load_dataset("penguins") # Draw a nested barplot by species and sex g = sns.catplot( data=penguins, kind="bar", 
x="species", y="body_mass_g", hue="sex", errorbar="sd", palette="dark", alpha=.6, height=6 ) g.despine(left=True) g.set_axis_labels("", "Body mass (g)") g.legend.set_title("") ================================================ FILE: examples/grouped_boxplot.py ================================================ """ Grouped boxplots ================ _thumb: .66, .45 """ import seaborn as sns sns.set_theme(style="ticks", palette="pastel") # Load the example tips dataset tips = sns.load_dataset("tips") # Draw a nested boxplot to show bills by day and time sns.boxplot(x="day", y="total_bill", hue="smoker", palette=["m", "g"], data=tips) sns.despine(offset=10, trim=True) ================================================ FILE: examples/grouped_violinplots.py ================================================ """ Grouped violinplots with split violins ====================================== _thumb: .44, .47 """ import seaborn as sns sns.set_theme(style="dark") # Load the example tips dataset tips = sns.load_dataset("tips") # Draw a nested violinplot and split the violins for easier comparison sns.violinplot(data=tips, x="day", y="total_bill", hue="smoker", split=True, inner="quart", fill=False, palette={"Yes": "g", "No": ".35"}) ================================================ FILE: examples/heat_scatter.py ================================================ """ Scatterplot heatmap ------------------- _thumb: .5, .5 """ import seaborn as sns sns.set_theme(style="whitegrid") # Load the brain networks dataset, select subset, and collapse the multi-index df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0) used_networks = [1, 5, 6, 7, 8, 12, 13, 17] used_columns = (df.columns .get_level_values("network") .astype(int) .isin(used_networks)) df = df.loc[:, used_columns] df.columns = df.columns.map("-".join) # Compute a correlation matrix and convert to long-form corr_mat = df.corr().stack().reset_index(name="correlation") # Draw each cell as a scatter point with varying 
size and color g = sns.relplot( data=corr_mat, x="level_0", y="level_1", hue="correlation", size="correlation", palette="vlag", hue_norm=(-1, 1), edgecolor=".7", height=10, sizes=(50, 250), size_norm=(-.2, .8), ) # Tweak the figure to finalize g.set(xlabel="", ylabel="", aspect="equal") g.despine(left=True, bottom=True) g.ax.margins(.02) for label in g.ax.get_xticklabels(): label.set_rotation(90) ================================================ FILE: examples/hexbin_marginals.py ================================================ """ Hexbin plot with marginal distributions ======================================= _thumb: .45, .4 """ import numpy as np import seaborn as sns sns.set_theme(style="ticks") rs = np.random.RandomState(11) x = rs.gamma(2, size=1000) y = -.5 * x + rs.normal(size=1000) sns.jointplot(x=x, y=y, kind="hex", color="#4CB391") ================================================ FILE: examples/histogram_stacked.py ================================================ """ Stacked histogram on a log scale ================================ _thumb: .5, .45 """ import seaborn as sns import matplotlib as mpl import matplotlib.pyplot as plt sns.set_theme(style="ticks") diamonds = sns.load_dataset("diamonds") f, ax = plt.subplots(figsize=(7, 5)) sns.despine(f) sns.histplot( diamonds, x="price", hue="cut", multiple="stack", palette="light:m_r", edgecolor=".3", linewidth=.5, log_scale=True, ) ax.xaxis.set_major_formatter(mpl.ticker.ScalarFormatter()) ax.set_xticks([500, 1000, 2000, 5000, 10000]) ================================================ FILE: examples/horizontal_boxplot.py ================================================ """ Horizontal boxplot with observations ==================================== _thumb: .7, .37 """ import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="ticks") # Initialize the figure with a logarithmic x axis f, ax = plt.subplots(figsize=(7, 6)) ax.set_xscale("log") # Load the example planets dataset planets = 
sns.load_dataset("planets") # Plot the distance with horizontal boxes sns.boxplot( planets, x="distance", y="method", hue="method", whis=[0, 100], width=.6, palette="vlag" ) # Add in points to show each observation sns.stripplot(planets, x="distance", y="method", size=4, color=".3") # Tweak the visual presentation ax.xaxis.grid(True) ax.set(ylabel="") sns.despine(trim=True, left=True) ================================================ FILE: examples/jitter_stripplot.py ================================================ """ Conditional means with observations =================================== """ import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="whitegrid") iris = sns.load_dataset("iris") # "Melt" the dataset to "long-form" or "tidy" representation iris = iris.melt(id_vars="species", var_name="measurement") # Initialize the figure f, ax = plt.subplots() sns.despine(bottom=True, left=True) # Show each observation with a scatterplot sns.stripplot( data=iris, x="value", y="measurement", hue="species", dodge=True, alpha=.25, zorder=1, legend=False, ) # Show the conditional means, aligning each pointplot in the # center of the strips by adjusting the width allotted to each # category (.8 by default) by the number of hue levels sns.pointplot( data=iris, x="value", y="measurement", hue="species", dodge=.8 - .8 / 3, palette="dark", errorbar=None, markers="d", markersize=4, linestyle="none", ) # Improve the legend sns.move_legend( ax, loc="lower right", ncol=3, frameon=True, columnspacing=1, handletextpad=0, ) ================================================ FILE: examples/joint_histogram.py ================================================ """ Joint and marginal histograms ============================= _thumb: .52, .505 """ import seaborn as sns sns.set_theme(style="ticks") # Load the planets dataset and initialize the figure planets = sns.load_dataset("planets") g = sns.JointGrid(data=planets, x="year", y="distance", marginal_ticks=True) # Set a
log scaling on the y axis g.ax_joint.set(yscale="log") # Create an inset legend for the histogram colorbar cax = g.figure.add_axes([.15, .55, .02, .2]) # Add the joint and marginal histogram plots g.plot_joint( sns.histplot, discrete=(True, False), cmap="light:#03012d", pmax=.8, cbar=True, cbar_ax=cax ) g.plot_marginals(sns.histplot, element="step", color="#03012d") ================================================ FILE: examples/joint_kde.py ================================================ """ Joint kernel density estimate ============================= _thumb: .6, .4 """ import seaborn as sns sns.set_theme(style="ticks") # Load the penguins dataset penguins = sns.load_dataset("penguins") # Show the joint distribution using kernel density estimation g = sns.jointplot( data=penguins, x="bill_length_mm", y="bill_depth_mm", hue="species", kind="kde", ) ================================================ FILE: examples/kde_ridgeplot.py ================================================ """ Overlapping densities ('ridge plot') ==================================== """ import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)}) # Create the data rs = np.random.RandomState(1979) x = rs.randn(500) g = np.tile(list("ABCDEFGHIJ"), 50) df = pd.DataFrame(dict(x=x, g=g)) m = df.g.map(ord) df["x"] += m # Initialize the FacetGrid object pal = sns.cubehelix_palette(10, rot=-.25, light=.7) g = sns.FacetGrid(df, row="g", hue="g", aspect=15, height=.5, palette=pal) # Draw the densities in a few steps g.map(sns.kdeplot, "x", bw_adjust=.5, clip_on=False, fill=True, alpha=1, linewidth=1.5) g.map(sns.kdeplot, "x", clip_on=False, color="w", lw=2, bw_adjust=.5) # passing color=None to refline() uses the hue mapping g.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False) # Define and use a simple function to label the plot in axes coordinates def label(x, color, label): ax = plt.gca() 
ax.text(0, .2, label, fontweight="bold", color=color, ha="left", va="center", transform=ax.transAxes) g.map(label, "x") # Set the subplots to overlap g.figure.subplots_adjust(hspace=-.25) # Remove axes details that don't play well with overlap g.set_titles("") g.set(yticks=[], ylabel="") g.despine(bottom=True, left=True) ================================================ FILE: examples/large_distributions.py ================================================ """ Plotting large distributions ============================ """ import seaborn as sns sns.set_theme(style="whitegrid") diamonds = sns.load_dataset("diamonds") clarity_ranking = ["I1", "SI2", "SI1", "VS2", "VS1", "VVS2", "VVS1", "IF"] sns.boxenplot( diamonds, x="clarity", y="carat", color="b", order=clarity_ranking, width_method="linear", ) ================================================ FILE: examples/layered_bivariate_plot.py ================================================ """ Bivariate plot with multiple elements ===================================== """ import numpy as np import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="dark") # Simulate data from a bivariate Gaussian n = 10000 mean = [0, 0] cov = [(2, .4), (.4, .2)] rng = np.random.RandomState(0) x, y = rng.multivariate_normal(mean, cov, n).T # Draw a combo histogram and scatterplot with density contours f, ax = plt.subplots(figsize=(6, 6)) sns.scatterplot(x=x, y=y, s=5, color=".15") sns.histplot(x=x, y=y, bins=50, pthresh=.1, cmap="mako") sns.kdeplot(x=x, y=y, levels=5, color="w", linewidths=1) ================================================ FILE: examples/logistic_regression.py ================================================ """ Faceted logistic regression =========================== _thumb: .58, .5 """ import seaborn as sns sns.set_theme(style="darkgrid") # Load the example Titanic dataset df = sns.load_dataset("titanic") # Make a custom palette with gendered colors pal = dict(male="#6495ED", female="#F08080") # Show the 
survival probability as a function of age and sex g = sns.lmplot(x="age", y="survived", col="sex", hue="sex", data=df, palette=pal, y_jitter=.02, logistic=True, truncate=False) g.set(xlim=(0, 80), ylim=(-.05, 1.05)) ================================================ FILE: examples/many_facets.py ================================================ """ Plotting on a large number of facets ==================================== _thumb: .4, .3 """ import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="ticks") # Create a dataset with many short random walks rs = np.random.RandomState(4) pos = rs.randint(-1, 2, (20, 5)).cumsum(axis=1) pos -= pos[:, 0, np.newaxis] step = np.tile(range(5), 20) walk = np.repeat(range(20), 5) df = pd.DataFrame(np.c_[pos.flat, step, walk], columns=["position", "step", "walk"]) # Initialize a grid of plots with an Axes for each walk grid = sns.FacetGrid(df, col="walk", hue="walk", palette="tab20c", col_wrap=4, height=1.5) # Draw a horizontal line to show the starting point grid.refline(y=0, linestyle=":") # Draw a line plot to show the trajectory of each random walk grid.map(plt.plot, "step", "position", marker="o") # Adjust the tick positions and labels grid.set(xticks=np.arange(5), yticks=[-3, 3], xlim=(-.5, 4.5), ylim=(-3.5, 3.5)) # Adjust the arrangement of the plots grid.fig.tight_layout(w_pad=1) ================================================ FILE: examples/many_pairwise_correlations.py ================================================ """ Plotting a diagonal correlation matrix ====================================== _thumb: .3, .6 """ from string import ascii_letters import numpy as np import pandas as pd import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="white") # Generate a large random dataset rs = np.random.RandomState(33) d = pd.DataFrame(data=rs.normal(size=(100, 26)), columns=list(ascii_letters[26:])) # Compute the correlation matrix corr = d.corr() # 
Generate a mask for the upper triangle mask = np.triu(np.ones_like(corr, dtype=bool)) # Set up the matplotlib figure f, ax = plt.subplots(figsize=(11, 9)) # Generate a custom diverging colormap cmap = sns.diverging_palette(230, 20, as_cmap=True) # Draw the heatmap with the mask and correct aspect ratio sns.heatmap(corr, mask=mask, cmap=cmap, vmax=.3, center=0, square=True, linewidths=.5, cbar_kws={"shrink": .5}) ================================================ FILE: examples/marginal_ticks.py ================================================ """ Scatterplot with marginal ticks =============================== _thumb: .66, .34 """ import seaborn as sns sns.set_theme(style="white", color_codes=True) mpg = sns.load_dataset("mpg") # Use JointGrid directly to draw a custom plot g = sns.JointGrid(data=mpg, x="mpg", y="acceleration", space=0, ratio=17) g.plot_joint(sns.scatterplot, size=mpg["horsepower"], sizes=(30, 120), color="g", alpha=.6, legend=False) g.plot_marginals(sns.rugplot, height=1, color="g", alpha=.6) ================================================ FILE: examples/multiple_bivariate_kde.py ================================================ """ Multiple bivariate KDE plots ============================ _thumb: .6, .45 """ import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="darkgrid") iris = sns.load_dataset("iris") # Set up the figure f, ax = plt.subplots(figsize=(8, 8)) ax.set_aspect("equal") # Draw a contour plot to represent each bivariate density sns.kdeplot( data=iris.query("species != 'versicolor'"), x="sepal_width", y="sepal_length", hue="species", thresh=.1, ) ================================================ FILE: examples/multiple_conditional_kde.py ================================================ """ Conditional kernel density estimate =================================== _thumb: .4, .5 """ import seaborn as sns sns.set_theme(style="whitegrid") # Load the diamonds dataset diamonds = sns.load_dataset("diamonds") # Plot the 
distribution of cut ratings, conditional on carat sns.displot( data=diamonds, x="carat", hue="cut", kind="kde", height=6, multiple="fill", clip=(0, None), palette="ch:rot=-.25,hue=1,light=.75", ) ================================================ FILE: examples/multiple_ecdf.py ================================================ """ Facetted ECDF plots =================== _thumb: .30, .49 """ import seaborn as sns sns.set_theme(style="ticks") mpg = sns.load_dataset("mpg") colors = (250, 70, 50), (350, 70, 50) cmap = sns.blend_palette(colors, input="husl", as_cmap=True) sns.displot( mpg, x="displacement", col="origin", hue="model_year", kind="ecdf", aspect=.75, linewidth=2, palette=cmap, ) ================================================ FILE: examples/multiple_regression.py ================================================ """ Multiple linear regression ========================== _thumb: .45, .45 """ import seaborn as sns sns.set_theme() # Load the penguins dataset penguins = sns.load_dataset("penguins") # Plot bill depth as a function of bill length across species g = sns.lmplot( data=penguins, x="bill_length_mm", y="bill_depth_mm", hue="species", height=5 ) # Use more informative axis labels than are provided by default g.set_axis_labels("Snoot length (mm)", "Snoot depth (mm)") ================================================ FILE: examples/pair_grid_with_kde.py ================================================ """ Paired density and scatterplot matrix ===================================== _thumb: .5, .5 """ import seaborn as sns sns.set_theme(style="white") df = sns.load_dataset("penguins") g = sns.PairGrid(df, diag_sharey=False) g.map_upper(sns.scatterplot, s=15) g.map_lower(sns.kdeplot) g.map_diag(sns.kdeplot, lw=2) ================================================ FILE: examples/paired_pointplots.py ================================================ """ Paired categorical plots ======================== """ import seaborn as sns sns.set_theme(style="whitegrid") #
Load the example Titanic dataset titanic = sns.load_dataset("titanic") # Set up a grid to plot survival probability against several variables g = sns.PairGrid(titanic, y_vars="survived", x_vars=["class", "sex", "who", "alone"], height=5, aspect=.5) # Draw a seaborn pointplot onto each Axes g.map(sns.pointplot, color="xkcd:plum") g.set(ylim=(0, 1)) sns.despine(fig=g.fig, left=True) ================================================ FILE: examples/pairgrid_dotplot.py ================================================ """ Dot plot with several variables =============================== _thumb: .3, .3 """ import seaborn as sns sns.set_theme(style="whitegrid") # Load the dataset crashes = sns.load_dataset("car_crashes") # Make the PairGrid g = sns.PairGrid(crashes.sort_values("total", ascending=False), x_vars=crashes.columns[:-3], y_vars=["abbrev"], height=10, aspect=.25) # Draw a dot plot using the stripplot function g.map(sns.stripplot, size=10, orient="h", jitter=False, palette="flare_r", linewidth=1, edgecolor="w") # Use the same x axis limits on all columns and add better labels g.set(xlim=(0, 25), xlabel="Crashes", ylabel="") # Use semantically meaningful titles for the columns titles = ["Total crashes", "Speeding crashes", "Alcohol crashes", "Not distracted crashes", "No previous crashes"] for ax, title in zip(g.axes.flat, titles): # Set a different title for each axes ax.set(title=title) # Make the grid horizontal instead of vertical ax.xaxis.grid(False) ax.yaxis.grid(True) sns.despine(left=True, bottom=True) ================================================ FILE: examples/palette_choices.py ================================================ """ Color palette choices ===================== """ import numpy as np import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="white", context="talk") rs = np.random.RandomState(8) # Set up the matplotlib figure f, (ax1, ax2, ax3) = plt.subplots(3, 1, figsize=(7, 5), sharex=True) # Generate some sequential data x 
= np.array(list("ABCDEFGHIJ")) y1 = np.arange(1, 11) sns.barplot(x=x, y=y1, hue=x, palette="rocket", ax=ax1) ax1.axhline(0, color="k", clip_on=False) ax1.set_ylabel("Sequential") # Center the data to make it diverging y2 = y1 - 5.5 sns.barplot(x=x, y=y2, hue=x, palette="vlag", ax=ax2) ax2.axhline(0, color="k", clip_on=False) ax2.set_ylabel("Diverging") # Randomly reorder the data to make it qualitative y3 = rs.choice(y1, len(y1), replace=False) sns.barplot(x=x, y=y3, hue=x, palette="deep", ax=ax3) ax3.axhline(0, color="k", clip_on=False) ax3.set_ylabel("Qualitative") # Finalize the plot sns.despine(bottom=True) plt.setp(f.axes, yticks=[]) plt.tight_layout(h_pad=2) ================================================ FILE: examples/palette_generation.py ================================================ """ Different cubehelix palettes ============================ _thumb: .4, .65 """ import numpy as np import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="white") rs = np.random.RandomState(50) # Set up the matplotlib figure f, axes = plt.subplots(3, 3, figsize=(9, 9), sharex=True, sharey=True) # Rotate the starting point around the cubehelix hue circle for ax, s in zip(axes.flat, np.linspace(0, 3, 10)): # Create a cubehelix colormap to use with kdeplot cmap = sns.cubehelix_palette(start=s, light=1, as_cmap=True) # Generate and plot a random bivariate dataset x, y = rs.normal(size=(2, 50)) sns.kdeplot( x=x, y=y, cmap=cmap, fill=True, clip=(-5, 5), cut=10, thresh=0, levels=15, ax=ax, ) ax.set_axis_off() ax.set(xlim=(-3.5, 3.5), ylim=(-3.5, 3.5)) f.subplots_adjust(0, 0, 1, 1, .08, .08) ================================================ FILE: examples/part_whole_bars.py ================================================ """ Horizontal bar plots ==================== """ import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="whitegrid") # Initialize the matplotlib figure f, ax = plt.subplots(figsize=(6, 15)) # Load the example car crash 
dataset crashes = sns.load_dataset("car_crashes").sort_values("total", ascending=False) # Plot the total crashes sns.set_color_codes("pastel") sns.barplot(x="total", y="abbrev", data=crashes, label="Total", color="b") # Plot the crashes where alcohol was involved sns.set_color_codes("muted") sns.barplot(x="alcohol", y="abbrev", data=crashes, label="Alcohol-involved", color="b") # Add a legend and informative axis label ax.legend(ncol=2, loc="lower right", frameon=True) ax.set(xlim=(0, 24), ylabel="", xlabel="Automobile collisions per billion miles") sns.despine(left=True, bottom=True) ================================================ FILE: examples/pointplot_anova.py ================================================ """ Plotting a three-way ANOVA ========================== _thumb: .42, .5 """ import seaborn as sns sns.set_theme(style="whitegrid") # Load the example exercise dataset exercise = sns.load_dataset("exercise") # Draw a pointplot to show pulse as a function of three categorical factors g = sns.catplot( data=exercise, x="time", y="pulse", hue="kind", col="diet", capsize=.2, palette="YlGnBu_d", errorbar="se", kind="point", height=6, aspect=.75, ) g.despine(left=True) ================================================ FILE: examples/radial_facets.py ================================================ """ FacetGrid with custom projection ================================ _thumb: .33, .5 """ import numpy as np import pandas as pd import seaborn as sns sns.set_theme() # Generate an example radial dataset r = np.linspace(0, 10, num=100) df = pd.DataFrame({'r': r, 'slow': r, 'medium': 2 * r, 'fast': 4 * r}) # Convert the dataframe to long-form or "tidy" format df = pd.melt(df, id_vars=['r'], var_name='speed', value_name='theta') # Set up a grid of axes with a polar projection g = sns.FacetGrid(df, col="speed", hue="speed", subplot_kws=dict(projection='polar'), height=4.5, sharex=False, sharey=False, despine=False) # Draw a scatterplot onto each axes in the grid
g.map(sns.scatterplot, "theta", "r") ================================================ FILE: examples/regression_marginals.py ================================================ """ Linear regression with marginal distributions ============================================= _thumb: .65, .65 """ import seaborn as sns sns.set_theme(style="darkgrid") tips = sns.load_dataset("tips") g = sns.jointplot(x="total_bill", y="tip", data=tips, kind="reg", truncate=False, xlim=(0, 60), ylim=(0, 12), color="m", height=7) ================================================ FILE: examples/residplot.py ================================================ """ Plotting model residuals ======================== """ import numpy as np import seaborn as sns sns.set_theme(style="whitegrid") # Make an example dataset with y ~ x rs = np.random.RandomState(7) x = rs.normal(2, 1, 75) y = 2 + 1.5 * x + rs.normal(0, 2, 75) # Plot the residuals after fitting a linear model sns.residplot(x=x, y=y, lowess=True, color="g") ================================================ FILE: examples/scatter_bubbles.py ================================================ """ Scatterplot with varying point sizes and hues ============================================== _thumb: .45, .5 """ import seaborn as sns sns.set_theme(style="white") # Load the example mpg dataset mpg = sns.load_dataset("mpg") # Plot miles per gallon against horsepower with other semantics sns.relplot(x="horsepower", y="mpg", hue="origin", size="weight", sizes=(40, 400), alpha=.5, palette="muted", height=6, data=mpg) ================================================ FILE: examples/scatterplot_categorical.py ================================================ """ Scatterplot with categorical variables ====================================== _thumb: .45, .45 """ import seaborn as sns sns.set_theme(style="whitegrid", palette="muted") # Load the penguins dataset df = sns.load_dataset("penguins") # Draw a categorical scatterplot to show each observation ax = 
sns.swarmplot(data=df, x="body_mass_g", y="sex", hue="species") ax.set(ylabel="") ================================================ FILE: examples/scatterplot_matrix.py ================================================ """ Scatterplot Matrix ================== _thumb: .3, .2 """ import seaborn as sns sns.set_theme(style="ticks") df = sns.load_dataset("penguins") sns.pairplot(df, hue="species") ================================================ FILE: examples/scatterplot_sizes.py ================================================ """ Scatterplot with continuous hues and sizes ========================================== _thumb: .51, .44 """ import seaborn as sns sns.set_theme(style="whitegrid") # Load the example planets dataset planets = sns.load_dataset("planets") cmap = sns.cubehelix_palette(rot=-.2, as_cmap=True) g = sns.relplot( data=planets, x="distance", y="orbital_period", hue="year", size="mass", palette=cmap, sizes=(10, 200), ) g.set(xscale="log", yscale="log") g.ax.xaxis.grid(True, "minor", linewidth=.25) g.ax.yaxis.grid(True, "minor", linewidth=.25) g.despine(left=True, bottom=True) ================================================ FILE: examples/simple_violinplots.py ================================================ """ Horizontal, unfilled violinplots ================================ _thumb: .5, .45 """ import seaborn as sns sns.set_theme() seaice = sns.load_dataset("seaice") seaice["Decade"] = seaice["Date"].dt.year.round(-1) sns.violinplot(seaice, x="Extent", y="Decade", orient="y", fill=False) ================================================ FILE: examples/smooth_bivariate_kde.py ================================================ """ Smooth kernel density with marginal histograms ============================================== _thumb: .48, .41 """ import seaborn as sns sns.set_theme(style="white") df = sns.load_dataset("penguins") g = sns.JointGrid(data=df, x="body_mass_g", y="bill_depth_mm", space=0) g.plot_joint(sns.kdeplot, fill=True, clip=((2200, 6800), (10, 
25)), thresh=0, levels=100, cmap="rocket") g.plot_marginals(sns.histplot, color="#03051A", alpha=1, bins=25) ================================================ FILE: examples/spreadsheet_heatmap.py ================================================ """ Annotated heatmaps ================== """ import matplotlib.pyplot as plt import seaborn as sns sns.set_theme() # Load the example flights dataset and convert to long-form flights_long = sns.load_dataset("flights") flights = ( flights_long .pivot(index="month", columns="year", values="passengers") ) # Draw a heatmap with the numeric values in each cell f, ax = plt.subplots(figsize=(9, 6)) sns.heatmap(flights, annot=True, fmt="d", linewidths=.5, ax=ax) ================================================ FILE: examples/strip_regplot.py ================================================ """ Regression fit over a strip plot ================================ _thumb: .53, .5 """ import seaborn as sns sns.set_theme() mpg = sns.load_dataset("mpg") sns.catplot( data=mpg, x="cylinders", y="acceleration", hue="weight", native_scale=True, zorder=1 ) sns.regplot( data=mpg, x="cylinders", y="acceleration", scatter=False, truncate=False, order=2, color=".2", ) ================================================ FILE: examples/structured_heatmap.py ================================================ """ Discovering structure in heatmap data ===================================== _thumb: .3, .25 """ import pandas as pd import seaborn as sns sns.set_theme() # Load the brain networks example dataset df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0) # Select a subset of the networks used_networks = [1, 5, 6, 7, 8, 12, 13, 17] used_columns = (df.columns.get_level_values("network") .astype(int) .isin(used_networks)) df = df.loc[:, used_columns] # Create a categorical palette to identify the networks network_pal = sns.husl_palette(8, s=.45) network_lut = dict(zip(map(str, used_networks), network_pal)) # Convert the palette to vectors 
that will be drawn on the side of the matrix networks = df.columns.get_level_values("network") network_colors = pd.Series(networks, index=df.columns).map(network_lut) # Draw the full plot g = sns.clustermap(df.corr(), center=0, cmap="vlag", row_colors=network_colors, col_colors=network_colors, dendrogram_ratio=(.1, .2), cbar_pos=(.02, .32, .03, .2), linewidths=.75, figsize=(12, 13)) g.ax_row_dendrogram.remove() ================================================ FILE: examples/three_variable_histogram.py ================================================ """ Trivariate histogram with two categorical variables =================================================== _thumb: .32, .55 """ import seaborn as sns sns.set_theme(style="dark") diamonds = sns.load_dataset("diamonds") sns.displot( data=diamonds, x="price", y="color", col="clarity", log_scale=(True, False), col_wrap=4, height=4, aspect=.7, ) ================================================ FILE: examples/timeseries_facets.py ================================================ """ Small multiple time series -------------------------- _thumb: .42, .58 """ import seaborn as sns sns.set_theme(style="dark") flights = sns.load_dataset("flights") # Plot each year's time series in its own facet g = sns.relplot( data=flights, x="month", y="passengers", col="year", hue="year", kind="line", palette="crest", linewidth=4, zorder=5, col_wrap=3, height=2, aspect=1.5, legend=False, ) # Iterate over each subplot to customize further for year, ax in g.axes_dict.items(): # Add the title as an annotation within the plot ax.text(.8, .85, year, transform=ax.transAxes, fontweight="bold") # Plot every year's time series in the background sns.lineplot( data=flights, x="month", y="passengers", units="year", estimator=None, color=".7", linewidth=1, ax=ax, ) # Reduce the frequency of the x axis ticks ax.set_xticks(ax.get_xticks()[::2]) # Tweak the supporting aspects of the plot g.set_titles("") g.set_axis_labels("", "Passengers") g.tight_layout() 
================================================ FILE: examples/wide_data_lineplot.py ================================================ """ Lineplot from a wide-form dataset ================================= _thumb: .52, .5 """ import numpy as np import pandas as pd import seaborn as sns sns.set_theme(style="whitegrid") rs = np.random.RandomState(365) values = rs.randn(365, 4).cumsum(axis=0) dates = pd.date_range("1 1 2016", periods=365, freq="D") data = pd.DataFrame(values, dates, columns=["A", "B", "C", "D"]) data = data.rolling(7).mean() sns.lineplot(data=data, palette="tab10", linewidth=2.5) ================================================ FILE: examples/wide_form_violinplot.py ================================================ """ Violinplot from a wide-form dataset =================================== _thumb: .6, .45 """ import seaborn as sns import matplotlib.pyplot as plt sns.set_theme(style="whitegrid") # Load the example dataset of brain network correlations df = sns.load_dataset("brain_networks", header=[0, 1, 2], index_col=0) # Pull out a specific subset of networks used_networks = [1, 3, 4, 5, 6, 7, 8, 11, 12, 13, 16, 17] used_columns = (df.columns.get_level_values("network") .astype(int) .isin(used_networks)) df = df.loc[:, used_columns] # Compute the correlation matrix and average over networks corr_df = df.corr().groupby(level="network").mean() corr_df.index = corr_df.index.astype(int) corr_df = corr_df.sort_index().T # Set up the matplotlib figure f, ax = plt.subplots(figsize=(11, 6)) # Draw a violinplot with a narrower bandwidth than the default sns.violinplot(data=corr_df, bw_adjust=.5, cut=1, linewidth=1, palette="Set3") # Finalize the figure ax.set(ylim=(-.7, 1.05)) sns.despine(left=True, bottom=True) ================================================ FILE: licences/APPDIRS_LICENSE ================================================ Copyright (c) 2005-2010 ActiveState Software Inc. 
Copyright (c) 2013 Eddy Petrișor This file is directly from https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py The license of https://github.com/ActiveState/appdirs copied below: # This is the MIT license Copyright (c) 2010 ActiveState Software Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
================================================ FILE: licences/HUSL_LICENSE ================================================ Copyright (C) 2012 Alexei Boronine Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: licences/NUMPYDOC_LICENSE ================================================ Copyright (C) 2008 Stefan van der Walt , Pauli Virtanen Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 
THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: licences/PACKAGING_LICENSE ================================================ Copyright (c) Donald Stufft and individual contributors. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: licences/SCIPY_LICENSE ================================================ Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["flit_core >=3.2,<4"] build-backend = "flit_core.buildapi" [project] name = "seaborn" description = "Statistical data visualization" authors = [{name = "Michael Waskom", email = "mwaskom@gmail.com"}] readme = "README.md" license = {file = "LICENSE.md"} dynamic = ["version"] classifiers = [ "Intended Audience :: Science/Research", "License :: OSI Approved :: BSD License", "Topic :: Scientific/Engineering :: Visualization", "Topic :: Multimedia :: Graphics", "Operating System :: OS Independent", "Framework :: Matplotlib", ] requires-python = ">=3.10" dependencies = [ "numpy>=1.20,!=1.24.0", "pandas>=1.2", "matplotlib>=3.4,!=3.6.1", ] [project.optional-dependencies] stats = [ "scipy>=1.7", "statsmodels>=0.12", ] dev = [ "pytest", "pytest-cov", "pytest-xdist", "flake8", "mypy", "pandas-stubs", "pre-commit", "flit", ] docs = [ "numpydoc", "nbconvert", "ipykernel", "sphinx<6.0.0", "sphinx-copybutton", "sphinx-issues", "sphinx-design", "pyyaml", "pydata_sphinx_theme==0.10.0rc2", ] [project.urls] Source = "https://github.com/mwaskom/seaborn" Docs = "http://seaborn.pydata.org" [tool.flit.sdist] exclude = ["doc/_static/*.svg"] [tool.pytest.ini_options] filterwarnings = [ "ignore:The --rsyncdir command line argument and rsyncdirs config variable are deprecated.:DeprecationWarning", "ignore:\\s*Pyarrow will become a required 
dependency of pandas:DeprecationWarning", "ignore:datetime.datetime.utcfromtimestamp\\(\\) is deprecated:DeprecationWarning", ] ================================================ FILE: seaborn/__init__.py ================================================ # Import seaborn objects from .rcmod import * # noqa: F401,F403 from .utils import * # noqa: F401,F403 from .palettes import * # noqa: F401,F403 from .relational import * # noqa: F401,F403 from .regression import * # noqa: F401,F403 from .categorical import * # noqa: F401,F403 from .distributions import * # noqa: F401,F403 from .matrix import * # noqa: F401,F403 from .miscplot import * # noqa: F401,F403 from .axisgrid import * # noqa: F401,F403 from .widgets import * # noqa: F401,F403 from .colors import xkcd_rgb, crayons # noqa: F401 from . import cm # noqa: F401 # Capture the original matplotlib rcParams import matplotlib as mpl _orig_rc_params = mpl.rcParams.copy() # Define the seaborn version __version__ = "0.14.0.dev0" ================================================ FILE: seaborn/_base.py ================================================ from __future__ import annotations import warnings import itertools from copy import copy from collections import UserString from collections.abc import Iterable, Sequence, Mapping from numbers import Number from datetime import datetime import numpy as np import pandas as pd import matplotlib as mpl from seaborn._core.data import PlotData from seaborn.palettes import ( QUAL_PALETTES, color_palette, ) from seaborn.utils import ( _check_argument, _version_predates, desaturate, locator_to_legend_entries, get_color_cycle, remove_na, ) class SemanticMapping: """Base class for mapping data values to plot attributes.""" # -- Default attributes that all SemanticMapping subclasses must set # Whether the mapping is numeric, categorical, or datetime map_type: str | None = None # Ordered list of unique values in the input data levels = None # A mapping from the data values to corresponding 
plot attributes lookup_table = None def __init__(self, plotter): # TODO Putting this here so we can continue to use a lot of the # logic that's built into the library, but the idea of this class # is to move towards semantic mappings that are agnostic about the # kind of plot they're going to be used to draw. # Fully achieving that is going to take some thinking. self.plotter = plotter def _check_list_length(self, levels, values, variable): """Input check when values are provided as a list.""" # Copied from _core/properties; eventually will be replaced for that. message = "" if len(levels) > len(values): message = " ".join([ f"\nThe {variable} list has fewer values ({len(values)})", f"than needed ({len(levels)}) and will cycle, which may", "produce an uninterpretable plot." ]) values = [x for _, x in zip(levels, itertools.cycle(values))] elif len(values) > len(levels): message = " ".join([ f"The {variable} list has more values ({len(values)})", f"than needed ({len(levels)}), which may not be intended.", ]) values = values[:len(levels)] if message: warnings.warn(message, UserWarning, stacklevel=6) return values def _lookup_single(self, key): """Apply the mapping to a single data value.""" return self.lookup_table[key] def __call__(self, key, *args, **kwargs): """Get the attribute(s) values for the data key.""" if isinstance(key, (list, np.ndarray, pd.Series)): return [self._lookup_single(k, *args, **kwargs) for k in key] else: return self._lookup_single(key, *args, **kwargs) class HueMapping(SemanticMapping): """Mapping that sets artist colors according to data values.""" # A specification of the colors that should appear in the plot palette = None # An object that normalizes data values to [0, 1] range for color mapping norm = None # A continuous colormap object for interpolating in a numeric context cmap = None def __init__( self, plotter, palette=None, order=None, norm=None, saturation=1, ): """Map the levels of the `hue` variable to distinct colors. 
Parameters ---------- # TODO add generic parameters """ super().__init__(plotter) data = plotter.plot_data.get("hue", pd.Series(dtype=float)) if isinstance(palette, np.ndarray): msg = ( "Numpy array is not a supported type for `palette`. " "Please convert your palette to a list. " "This will become an error in v0.14" ) warnings.warn(msg, stacklevel=4) palette = palette.tolist() if data.isna().all(): if palette is not None: msg = "Ignoring `palette` because no `hue` variable has been assigned." warnings.warn(msg, stacklevel=4) else: map_type = self.infer_map_type( palette, norm, plotter.input_format, plotter.var_types["hue"] ) # Our goal is to end up with a dictionary mapping every unique # value in `data` to a color. We will also keep track of the # metadata about this mapping we will need for, e.g., a legend # --- Option 1: numeric mapping with a matplotlib colormap if map_type == "numeric": data = pd.to_numeric(data) levels, lookup_table, norm, cmap = self.numeric_mapping( data, palette, norm, ) # --- Option 2: categorical mapping using seaborn palette elif map_type == "categorical": cmap = norm = None levels, lookup_table = self.categorical_mapping( data, palette, order, ) # --- Option 3: datetime mapping else: # TODO this needs actual implementation cmap = norm = None levels, lookup_table = self.categorical_mapping( # Casting data to list to handle differences in the way # pandas and numpy represent datetime64 data list(data), palette, order, ) self.saturation = saturation self.map_type = map_type self.lookup_table = lookup_table self.palette = palette self.levels = levels self.norm = norm self.cmap = cmap def _lookup_single(self, key): """Get the color for a single value, using colormap to interpolate.""" try: # Use a value that's in the original data vector value = self.lookup_table[key] except KeyError: if self.norm is None: # Currently we only get here in scatterplot with hue_order, # because scatterplot does not consider hue a grouping variable # So unused 
def categorical_mapping(self, data, palette, order):
    """Determine colors when the hue mapping is categorical.

    Parameters
    ----------
    data : vector
        Values of the hue variable; passed to ``categorical_order``.
    palette : dict, list, None, or other palette specification
        A dict is used directly as the level -> color lookup table. A list
        is validated against the levels with ``_check_list_length``. ``None``
        selects the current color cycle when it has enough entries and the
        "husl" palette otherwise. Anything else is handed to
        ``color_palette``.
    order : list or None
        Explicit level order, passed to ``categorical_order``.

    Returns
    -------
    levels
        Ordered levels as returned by ``categorical_order``.
    lookup_table : dict
        Mapping from each level to its color.

    Raises
    ------
    ValueError
        If `palette` is a dict that lacks an entry for any level.

    """
    # -- Identify the order and name of the levels
    levels = categorical_order(data, order)
    n_colors = len(levels)

    # -- Identify the set of colors to use
    if isinstance(palette, dict):

        missing = set(levels) - set(palette)
        # Check whether the set is non-empty rather than whether any member
        # is truthy: a missing level such as 0, False, or "" must still be
        # reported instead of slipping through.
        if missing:
            err = "The palette dictionary is missing keys: {}"
            raise ValueError(err.format(missing))

        lookup_table = palette

    else:

        if palette is None:
            if n_colors <= len(get_color_cycle()):
                colors = color_palette(None, n_colors)
            else:
                colors = color_palette("husl", n_colors)
        elif isinstance(palette, list):
            colors = self._check_list_length(levels, palette, "palette")
        else:
            colors = color_palette(palette, n_colors)

        lookup_table = dict(zip(levels, colors))

    return levels, lookup_table
class SizeMapping(SemanticMapping):
    """Mapping that sets artist sizes according to data values."""
    # An object that normalizes data values to [0, 1] range
    norm = None

    def __init__(
        self, plotter, sizes=None, order=None, norm=None,
    ):
        """Map the levels of the `size` variable to distinct values.

        Parameters
        ----------
        # TODO add generic parameters

        """
        super().__init__(plotter)

        data = plotter.plot_data.get("size", pd.Series(dtype=float))

        # Only build a mapping when there is at least one non-missing value
        if data.notna().any():

            map_type = self.infer_map_type(
                norm, sizes, plotter.var_types["size"]
            )

            # --- Option 1: numeric mapping

            if map_type == "numeric":

                levels, lookup_table, norm, size_range = self.numeric_mapping(
                    data, sizes, norm,
                )

            # --- Option 2: categorical mapping

            elif map_type == "categorical":

                levels, lookup_table = self.categorical_mapping(
                    data, sizes, order,
                )
                size_range = None

            # --- Option 3: datetime mapping

            # TODO this needs an actual implementation
            else:

                levels, lookup_table = self.categorical_mapping(
                    # Casting data to list to handle differences in the way
                    # pandas and numpy represent datetime64 data
                    list(data), sizes, order,
                )
                size_range = None

            self.map_type = map_type
            self.levels = levels
            self.norm = norm
            self.sizes = sizes
            self.size_range = size_range
            self.lookup_table = lookup_table

    def infer_map_type(self, norm, sizes, var_type):
        """Determine how to implement the mapping.

        A norm forces a numeric mapping, a dict or list of sizes forces a
        categorical one, and otherwise the variable's own type is used.
        """
        if norm is not None:
            map_type = "numeric"
        elif isinstance(sizes, (dict, list)):
            map_type = "categorical"
        else:
            map_type = var_type

        return map_type

    def _lookup_single(self, key):
        """Get the size for a single value, using the norm to interpolate."""
        try:
            value = self.lookup_table[key]
        except KeyError:
            # Not one of the known levels: place it within the size range
            # according to its normalized position.
            normed = self.norm(key)
            if np.ma.is_masked(normed):
                normed = np.nan
            value = self.size_range[0] + normed * np.ptp(self.size_range)
        return value

    def categorical_mapping(self, data, sizes, order):
        """Determine sizes when the size mapping is categorical.

        Returns
        -------
        levels
            Ordered levels as returned by ``categorical_order``.
        lookup_table : dict
            Mapping from each level to its size.

        Raises
        ------
        ValueError
            If a `sizes` dict lacks an entry for any level, if a `sizes`
            tuple does not have exactly two values, or if `sizes` is of an
            unsupported type.

        """
        levels = categorical_order(data, order)

        if isinstance(sizes, dict):

            # Dict inputs map existing data values to the size attribute
            missing = set(levels) - set(sizes)
            # Check whether the set is non-empty rather than whether any
            # member is truthy, so a missing level such as 0, False, or ""
            # is still reported.
            if missing:
                err = f"Missing sizes for the following levels: {missing}"
                raise ValueError(err)
            lookup_table = sizes.copy()

        elif isinstance(sizes, list):

            # List inputs give size values in the same order as the levels
            sizes = self._check_list_length(levels, sizes, "sizes")
            lookup_table = dict(zip(levels, sizes))

        else:

            if isinstance(sizes, tuple):

                # Tuple input sets the min, max size values
                if len(sizes) != 2:
                    err = "A `sizes` tuple must have only 2 values"
                    raise ValueError(err)

            elif sizes is not None:

                err = f"Value for `sizes` not understood: {sizes}"
                raise ValueError(err)

            else:

                # Otherwise, we need to get the min, max size values from
                # the plotter object we are attached to.

                # TODO this is going to cause us trouble later, because we
                # want to restructure things so that the plotter is generic
                # across the visual representation of the data. But at this
                # point, we don't know the visual representation. Likely we
                # want to change the logic of this Mapping so that it gives
                # points on a normalized range that then gets un-normalized
                # when we know what we're drawing. But given the way the
                # package works now, this way is cleanest.
                sizes = self.plotter._default_size_range

            # For categorical sizes, use regularly-spaced linear steps
            # between the minimum and maximum sizes. Then reverse the
            # ramp so that the largest value is used for the first entry
            # in size_order, etc. This is because "ordered" categories
            # are often thought to go in decreasing priority.
            sizes = np.linspace(*sizes, len(levels))[::-1]
            lookup_table = dict(zip(levels, sizes))

        return levels, lookup_table

    def numeric_mapping(self, data, sizes, norm):
        """Determine sizes when the size variable is quantitative.

        Returns
        -------
        levels : list
            Sorted keys of a `sizes` dict, or the sorted unique non-missing
            data values.
        lookup_table : dict
            Mapping from each level to its size.
        norm : matplotlib Normalize
            Normalization used to map data values into [0, 1], with
            clipping enabled.
        size_range : tuple
            Minimum and maximum size values.

        Raises
        ------
        ValueError
            If `sizes` or `norm` is of an unsupported type, or a `sizes`
            tuple does not have exactly two values.

        """
        if isinstance(sizes, dict):
            # The presence of a norm object overrides a dictionary of sizes
            # in specifying a numeric mapping, so we need to process the
            # dictionary here
            levels = list(np.sort(list(sizes)))
            size_values = sizes.values()
            size_range = min(size_values), max(size_values)

        else:

            # The levels here will be the unique values in the data
            levels = list(np.sort(remove_na(data.unique())))

            if isinstance(sizes, tuple):

                # For numeric inputs, the size can be parametrized by
                # the minimum and maximum artist values to map to. The
                # norm object that gets set up next specifies how to
                # do the mapping.

                if len(sizes) != 2:
                    err = "A `sizes` tuple must have only 2 values"
                    raise ValueError(err)

                size_range = sizes

            elif sizes is not None:

                err = f"Value for `sizes` not understood: {sizes}"
                raise ValueError(err)

            else:

                # When not provided, we get the size range from the plotter
                # object we are attached to. See the note in the categorical
                # method about how this is suboptimal for future development.
                size_range = self.plotter._default_size_range

        # Now that we know the minimum and maximum sizes that will get drawn,
        # we need to map the data values that we have into that range. We will
        # use a matplotlib Normalize class, which is typically used for numeric
        # color mapping but works fine here too. It takes data values and maps
        # them into a [0, 1] interval, potentially nonlinear-ly.

        if norm is None:
            # Default is a linear function between the min and max data values
            norm = mpl.colors.Normalize()
        elif isinstance(norm, tuple):
            # It is also possible to give different limits in data space
            norm = mpl.colors.Normalize(*norm)
        elif not isinstance(norm, mpl.colors.Normalize):
            err = f"Value for size `norm` parameter not understood: {norm}"
            raise ValueError(err)
        else:
            # If provided with Normalize object, copy it so we can modify
            norm = copy(norm)

        # Set the mapping so all output values are in [0, 1]
        norm.clip = True

        # If the input range is not set, use the full range of the data
        if not norm.scaled():
            norm(levels)

        # Map from data values to [0, 1] range
        sizes_scaled = norm(levels)

        # Now map from the scaled range into the artist units
        if isinstance(sizes, dict):
            lookup_table = sizes
        else:
            lo, hi = size_range
            sizes = lo + sizes_scaled * (hi - lo)
            lookup_table = dict(zip(levels, sizes))

        return levels, lookup_table, norm, size_range
`style` variable to distinct values. Parameters ---------- # TODO add generic parameters """ super().__init__(plotter) data = plotter.plot_data.get("style", pd.Series(dtype=float)) if data.notna().any(): # Cast to list to handle numpy/pandas datetime quirks if variable_type(data) == "datetime": data = list(data) # Find ordered unique values levels = categorical_order(data, order) markers = self._map_attributes( markers, levels, unique_markers(len(levels)), "markers", ) dashes = self._map_attributes( dashes, levels, unique_dashes(len(levels)), "dashes", ) # Build the paths matplotlib will use to draw the markers paths = {} filled_markers = [] for k, m in markers.items(): if not isinstance(m, mpl.markers.MarkerStyle): m = mpl.markers.MarkerStyle(m) paths[k] = m.get_path().transformed(m.get_transform()) filled_markers.append(m.is_filled()) # Mixture of filled and unfilled markers will show line art markers # in the edge color, which defaults to white. This can be handled, # but there would be additional complexity with specifying the # weight of the line art markers without overwhelming the filled # ones with the edges. So for now, we will disallow mixtures. 
def __init__(self, data=None, variables=None):
    """Assign the plot variables and initialize default semantic mappings.

    Parameters
    ----------
    data : object, optional
        Data source forwarded unchanged to ``assign_variables``.
    variables : dict, optional
        Mapping of variable names (e.g. "x", "hue") to their specification.
        ``None`` is treated as an empty mapping.

    """
    # Use None as the default rather than a dict literal: a literal default
    # is a single object shared by every call, so any downstream mutation
    # would leak between plotter instances.
    variables = {} if variables is None else variables

    self._var_levels = {}
    # var_ordered is relevant only for categorical axis variables, and may
    # be better handled by an internal axis information object that tracks
    # such information and is set up by the scale_* methods. The analogous
    # information for numeric axes would be information about log scales.
    self._var_ordered = {"x": False, "y": False}  # alt., used DefaultDict
    self.assign_variables(data, variables)

    # TODO Lots of tests assume that these are called to initialize the
    # mappings to default values on class initialization. I'd prefer to
    # move away from that and only have a mapping when explicitly called.
    for var in ["hue", "size", "style"]:
        if var in variables:
            getattr(self, f"map_{var}")()
def _assign_variables_wideform(self, data=None, **kwargs):
    """Define plot variables given wide-form data.

    Parameters
    ----------
    data : flat vector or collection of vectors
        Data can be a vector or mapping that is coerceable to a Series
        or a sequence- or mapping-based collection of such vectors, or a
        rectangular numpy array, or a Pandas DataFrame.
    kwargs : variable -> data mappings
        Any keyword whose value is not None raises, because semantic
        variables cannot be assigned with wide-form data.

    Returns
    -------
    plot_data : :class:`pandas.DataFrame`
        Long-form data object mapping seaborn variables (x, y, hue, ...)
        to data vectors.
    variables : dict
        Keys are defined seaborn variables; values are names inferred from
        the inputs (or None when no name can be determined).

    Raises
    ------
    ValueError
        If any keyword variable is assigned a non-None value.

    """
    # Raise if semantic or other variables are assigned in wide-form mode
    assigned = [k for k, v in kwargs.items() if v is not None]
    if assigned:
        s = "s" if len(assigned) > 1 else ""
        err = f"The following variable{s} cannot be assigned with wide-form data: "
        err += ", ".join(f"`{v}`" for v in assigned)
        raise ValueError(err)

    # Determine if the data object actually has any data in it
    empty = data is None or not len(data)

    # Then, determine if we have "flat" data (a single vector)
    if isinstance(data, dict):
        values = data.values()
    else:
        values = np.atleast_1d(np.asarray(data, dtype=object))
    flat = not any(
        isinstance(v, Iterable) and not isinstance(v, (str, bytes))
        for v in values
    )

    if empty:

        # Make an object with the structure of plot_data, but empty
        plot_data = pd.DataFrame()
        variables = {}

    elif flat:

        # Handle flat data by converting to pandas Series and using the
        # index and/or values to define x and/or y
        # (Could be accomplished with a more general to_series() interface)
        flat_data = pd.Series(data).copy()
        names = {
            "@values": flat_data.name,
            "@index": flat_data.index.name
        }

        plot_data = {}
        variables = {}

        for var in ["x", "y"]:
            if var in self.flat_structure:
                attr = self.flat_structure[var]
                # Strip the leading "@" to get the Series attribute name
                plot_data[var] = getattr(flat_data, attr[1:])
                variables[var] = names[self.flat_structure[var]]

        plot_data = pd.DataFrame(plot_data)

    else:

        # Otherwise assume we have some collection of vectors.

        # Handle Python sequences such that entries end up in the columns,
        # not in the rows, of the intermediate wide DataFrame.
        # One way to accomplish this is to convert to a dict of Series.
        if isinstance(data, Sequence):
            data_dict = {}
            for i, var in enumerate(data):
                key = getattr(var, "name", i)
                # TODO is there a safer/more generic way to ensure Series?
                # sort of like np.asarray, but for pandas?
                data_dict[key] = pd.Series(var)

            data = data_dict

        # Pandas requires that dict values either be Series objects
        # or all have the same length, but we want to allow "ragged" inputs
        if isinstance(data, Mapping):
            data = {key: pd.Series(val) for key, val in data.items()}

        # Otherwise, delegate to the pandas DataFrame constructor
        # This is where we'd prefer to use a general interface that says
        # "give me this data as a pandas DataFrame", so we can accept
        # DataFrame objects from other libraries
        wide_data = pd.DataFrame(data, copy=True)

        # At this point we should reduce the dataframe to numeric cols
        numeric_cols = [
            k for k, v in wide_data.items() if variable_type(v) == "numeric"
        ]
        wide_data = wide_data[numeric_cols]

        # Now melt the data to long form
        melt_kws = {"var_name": "@columns", "value_name": "@values"}
        use_index = "@index" in self.wide_structure.values()
        if use_index:
            melt_kws["id_vars"] = "@index"
            try:
                orig_categories = wide_data.columns.categories
                orig_ordered = wide_data.columns.ordered
                # Categorical column labels only accept known categories,
                # so register "@index" before inserting that column below;
                # the original categories are restored after melting.
                wide_data.columns = wide_data.columns.add_categories("@index")
            except AttributeError:
                category_columns = False
            else:
                category_columns = True
            wide_data["@index"] = wide_data.index.to_series()

        plot_data = wide_data.melt(**melt_kws)

        if use_index and category_columns:
            plot_data["@columns"] = pd.Categorical(plot_data["@columns"],
                                                   orig_categories,
                                                   orig_ordered)

        # Assign names corresponding to plot semantics
        for var, attr in self.wide_structure.items():
            plot_data[var] = plot_data[attr]

        # Define the variable names
        variables = {}
        for var, attr in self.wide_structure.items():
            obj = getattr(wide_data, attr[1:])
            variables[var] = getattr(obj, "name", None)

        # Remove redundant columns from plot_data
        plot_data = plot_data[list(variables)]

    return plot_data, variables
def iter_data(
    self, grouping_vars=None, *,
    reverse=False, from_comp_data=False,
    by_facet=True, allow_empty=False, dropna=True,
):
    """Generator for getting subsets of data defined by semantic variables.

    Also injects "col" and "row" into grouping semantics.

    Parameters
    ----------
    grouping_vars : string or list of strings
        Semantic variables that define the subsets of data. The object
        passed in is never modified.
    reverse : bool
        If True, reverse the order of iteration.
    from_comp_data : bool
        If True, use self.comp_data rather than self.plot_data
    by_facet : bool
        If True, add faceting variables to the set of grouping variables.
    allow_empty : bool
        If True, yield an empty dataframe when no observations exist for
        combinations of grouping variables.
    dropna : bool
        If True, remove rows with missing data.

    Yields
    ------
    sub_vars : dict
        Keys are semantic names, values are the level of that semantic.
    sub_data : :class:`pandas.DataFrame`
        Subset of ``plot_data`` for this combination of semantic values.

    """
    # TODO should this default to using all (non x/y?) semantics?
    # or define grouping vars somewhere?
    if grouping_vars is None:
        grouping_vars = []
    elif isinstance(grouping_vars, str):
        grouping_vars = [grouping_vars]
    else:
        # Always work on a private copy: the facet variables appended
        # below must not leak into a list owned by the caller.
        grouping_vars = list(grouping_vars)

    # Always insert faceting variables
    if by_facet:
        # Use a fixed col-then-row order so iteration does not depend on
        # set ordering.
        facet_vars = [
            var for var in ("col", "row")
            if var in self.variables and var not in grouping_vars
        ]
        grouping_vars.extend(facet_vars)

    # Reduce to the semantics used in this plot
    grouping_vars = [var for var in grouping_vars if var in self.variables]

    if from_comp_data:
        data = self.comp_data
    else:
        data = self.plot_data

    if dropna:
        data = data.dropna()

    levels = self.var_levels.copy()
    if from_comp_data:
        for axis in {"x", "y"} & set(grouping_vars):
            converter = self.converters[axis].iloc[0]
            if self.var_types[axis] == "categorical":
                if self._var_ordered[axis]:
                    # If the axis is ordered, then the axes in a possible
                    # facet grid are by definition "shared", or there is a
                    # single axis with a unique cat -> idx mapping.
                    # So we can just take the first converter object.
                    levels[axis] = converter.convert_units(levels[axis])
                else:
                    # Otherwise, the mappings may not be unique, but we can
                    # use the unique set of index values in comp_data.
                    levels[axis] = np.sort(data[axis].unique())
            else:
                transform = converter.get_transform().transform
                levels[axis] = transform(converter.convert_units(levels[axis]))

    if grouping_vars:

        grouped_data = data.groupby(
            grouping_vars, sort=False, as_index=False, observed=False,
        )

        # A variable without known levels contributes an empty list, which
        # makes the product below empty.
        grouping_keys = []
        for var in grouping_vars:
            key = levels.get(var)
            grouping_keys.append([] if key is None else key)

        iter_keys = itertools.product(*grouping_keys)
        if reverse:
            iter_keys = reversed(list(iter_keys))

        for key in iter_keys:

            # Older pandas expects a scalar key for a single grouper
            pd_key = (
                key[0] if len(key) == 1 and _version_predates(pd, "2.2.0") else key
            )
            try:
                data_subset = grouped_data.get_group(pd_key)
            except KeyError:
                # XXX we are adding this to allow backwards compatibility
                # with the empty artists that old categorical plots would
                # add (before 0.12), which we may decide to break, in which
                # case this option could be removed
                data_subset = data.loc[[]]

            if data_subset.empty and not allow_empty:
                continue

            sub_vars = dict(zip(grouping_vars, key))

            yield sub_vars, data_subset.copy()

    else:

        yield {}, data.copy()
def _get_axes(self, sub_vars):
    """Pick the Axes that corresponds to the row/col levels in `sub_vars`.

    With both a row and a col level the facet lookup key is the
    ``(row, col)`` tuple; with only one of them it is that level alone.
    Without facet levels, the plotter's own ``ax`` is returned, falling
    back to ``facets.ax`` when ``ax`` is None.
    """
    row_level = sub_vars.get("row")
    col_level = sub_vars.get("col")
    has_row = row_level is not None
    has_col = col_level is not None

    if has_row and has_col:
        return self.facets.axes_dict[(row_level, col_level)]
    if has_row:
        return self.facets.axes_dict[row_level]
    if has_col:
        return self.facets.axes_dict[col_level]

    # No faceting information in sub_vars
    if self.ax is None:
        return self.facets.ax
    return self.ax
log_scale : bool, number, or pair of bools or numbers If not False, set the axes to use log scaling, with the given base or defaulting to 10. If a tuple, interpreted as separate arguments for the x and y axes. """ from .axisgrid import FacetGrid if isinstance(obj, FacetGrid): self.ax = None self.facets = obj ax_list = obj.axes.flatten() if obj.col_names is not None: self.var_levels["col"] = obj.col_names if obj.row_names is not None: self.var_levels["row"] = obj.row_names else: self.ax = obj self.facets = None ax_list = [obj] # Identify which "axis" variables we have defined axis_variables = set("xy").intersection(self.variables) # -- Verify the types of our x and y variables here. # This doesn't really make complete sense being here here, but it's a fine # place for it, given the current system. # (Note that for some plots, there might be more complicated restrictions) # e.g. the categorical plots have their own check that as specific to the # non-categorical axis. if allowed_types is None: allowed_types = ["numeric", "datetime", "categorical"] elif isinstance(allowed_types, str): allowed_types = [allowed_types] for var in axis_variables: var_type = self.var_types[var] if var_type not in allowed_types: err = ( f"The {var} variable is {var_type}, but one of " f"{allowed_types} is required" ) raise TypeError(err) # -- Get axis objects for each row in plot_data for type conversions and scaling facet_dim = {"x": "col", "y": "row"} self.converters = {} for var in axis_variables: other_var = {"x": "y", "y": "x"}[var] converter = pd.Series(index=self.plot_data.index, name=var, dtype=object) share_state = getattr(self.facets, f"_share{var}", True) # Simplest cases are that we have a single axes, all axes are shared, # or sharing is only on the orthogonal facet dimension. 
In these cases, # all datapoints get converted the same way, so use the first axis if share_state is True or share_state == facet_dim[other_var]: converter.loc[:] = getattr(ax_list[0], f"{var}axis") else: # Next simplest case is when no axes are shared, and we can # use the axis objects within each facet if share_state is False: for axes_vars, axes_data in self.iter_data(): ax = self._get_axes(axes_vars) converter.loc[axes_data.index] = getattr(ax, f"{var}axis") # In the more complicated case, the axes are shared within each # "file" of the facetgrid. In that case, we need to subset the data # for that file and assign it the first axis in the slice of the grid else: names = getattr(self.facets, f"{share_state}_names") for i, level in enumerate(names): idx = (i, 0) if share_state == "row" else (0, i) axis = getattr(self.facets.axes[idx], f"{var}axis") converter.loc[self.plot_data[share_state] == level] = axis # Store the converter vector, which we use elsewhere (e.g comp_data) self.converters[var] = converter # Now actually update the matplotlib objects to do the conversion we want grouped = self.plot_data[var].groupby(self.converters[var], sort=False) for converter, seed_data in grouped: if self.var_types[var] == "categorical": if self._var_ordered[var]: order = self.var_levels[var] else: order = None seed_data = categorical_order(seed_data, order) converter.update_units(seed_data) # -- Set numerical axis scales # First unpack the log_scale argument if log_scale is None: scalex = scaley = False else: # Allow single value or x, y tuple try: scalex, scaley = log_scale except TypeError: scalex = log_scale if self.var_types.get("x") == "numeric" else False scaley = log_scale if self.var_types.get("y") == "numeric" else False # Now use it for axis, scale in zip("xy", (scalex, scaley)): if scale: for ax in ax_list: set_scale = getattr(ax, f"set_{axis}scale") if scale is True: set_scale("log", nonpositive="mask") else: set_scale("log", base=scale, nonpositive="mask") # 
def _get_scale_transforms(self, axis):
    """Return the forward and inverse scale transform functions for `axis`.

    Returns a ``(transform, inverse_transform)`` pair of callables taken
    from the axis transform object. Raises RuntimeError when the plotter is
    faceted and the facet axes do not all report the same scale.
    """
    axis_attr = f"{axis}axis"

    if self.ax is not None:
        # Single Axes: this case is more straightforward
        transform_obj = getattr(self.ax, axis_attr).get_transform()
    else:
        facet_axis_objs = [
            getattr(facet_ax, axis_attr) for facet_ax in self.facets.axes.flat
        ]
        distinct_scales = set()
        for axis_obj in facet_axis_objs:
            distinct_scales.add(axis_obj.get_scale())
        if len(distinct_scales) > 1:
            # It is a simplifying assumption that faceted axes will always have
            # the same scale (even if they are unshared and have distinct limits).
            # Nothing in the seaborn API allows you to create a FacetGrid with
            # a mixture of scales, although it's possible via matplotlib.
            # This is constraining, but no more so than previous behavior that
            # only (properly) handled log scales, and there are some places where
            # it would be much too complicated to use axes-specific transforms.
            err = "Cannot determine transform with mixed scales on faceted axes."
            raise RuntimeError(err)
        transform_obj = facet_axis_objs[0].get_transform()

    forward = transform_obj.transform
    inverse = transform_obj.inverted().transform
    return forward, inverse
def add_legend_data(
    self, ax, func, common_kws=None, attrs=None, semantic_kws=None,
):
    """Add labeled artists to represent the different plot semantics.

    For each legend entry gathered through ``_update_legend_data``, an
    artist is created with `func` and attached to `ax`. The results are
    stored on ``self.legend_title``, ``self.legend_data`` and
    ``self.legend_order``.

    Raises ValueError when ``self.legend`` is a string other than
    'auto', 'brief' or 'full'.
    """
    verbosity = self.legend
    if isinstance(verbosity, str):
        if verbosity not in ["auto", "brief", "full"]:
            err = "`legend` must be 'auto', 'brief', 'full', or a boolean."
            raise ValueError(err)
    elif verbosity is True:
        verbosity = "auto"

    # Entries are collected in first-seen order, each with its artist kwargs
    entry_keys = []
    entry_kws = {}
    # Copy so that the caller's dictionaries are never modified
    common_kws = common_kws.copy() if common_kws is not None else {}
    semantic_kws = semantic_kws.copy() if semantic_kws is not None else {}

    # Assign a legend title if there is only going to be one sub-legend,
    # otherwise, subtitles will be inserted into the texts list with an
    # invisible handle (which is a hack)
    semantic_names = (
        self.variables.get(v, None) for v in ["hue", "size", "style"]
    )
    titles = {name for name in semantic_names if name is not None}
    title = titles.pop() if len(titles) == 1 else ""

    title_kws = dict(
        visible=False, color="w", s=0, linewidth=0, marker="", dashes=""
    )

    def update(var_name, val_name, **kws):
        # Merge kwargs into an existing entry or register a new one
        key = var_name, val_name
        if key not in entry_kws:
            entry_keys.append(key)
            entry_kws[key] = dict(**kws)
        else:
            entry_kws[key].update(**kws)

    if attrs is None:
        attrs = {"hue": "color", "size": ["linewidth", "s"], "style": None}
    for var, names in attrs.items():
        self._update_legend_data(
            update, var, verbosity, title, title_kws, names, semantic_kws.get(var),
        )

    legend_data = {}
    legend_order = []

    # Don't allow color=None so we can set a neutral color for size/style legends
    if "color" in common_kws and common_kws["color"] is None:
        del common_kws["color"]

    for key in entry_keys:

        _, label = key
        kws = entry_kws[key]
        use_attrs = [
            *self._legend_attributes,
            *common_kws,
            *[attr for var_attrs in semantic_kws.values() for attr in var_attrs],
        ]
        # Keep only the recognized attributes that this entry defines
        level_kws = {attr: kws[attr] for attr in use_attrs if attr in kws}

        artist = func(label=label, **{"color": ".2", **common_kws, **level_kws})

        if _version_predates(mpl, "3.5.0"):
            if isinstance(artist, mpl.lines.Line2D):
                ax.add_line(artist)
            elif isinstance(artist, mpl.patches.Patch):
                ax.add_patch(artist)
            elif isinstance(artist, mpl.collections.Collection):
                ax.add_collection(artist)
        else:
            ax.add_artist(artist)

        legend_data[key] = artist
        legend_order.append(key)

    self.legend_title = title
    self.legend_data = legend_data
    self.legend_order = legend_order
def scale_native(self, axis, *args, **kwargs):
    """Placeholder for native scaling of ``axis``; always raises NotImplementedError."""

    # Default, defer to matplotlib

    raise NotImplementedError

def scale_numeric(self, axis, *args, **kwargs):
    """Placeholder for numeric scaling of ``axis``; always raises NotImplementedError."""

    # Feels needed for completeness, what should it do?
    # Perhaps handle log scaling? Set the ticker/formatter/limits?

    raise NotImplementedError

def scale_datetime(self, axis, *args, **kwargs):
    """Placeholder for datetime scaling of ``axis``; always raises NotImplementedError."""

    # Use pd.to_datetime to convert strings or numbers to datetime objects
    # Note, use day-resolution for numeric->datetime to match matplotlib

    raise NotImplementedError
# If we go that route, these methods could become "borrowed" methods similar # to what happens with the alternate semantic mapper constructors, although # that approach is kind of fussy and confusing. # TODO this method could also set the grid state? Since we like to have no # grid on the categorical axis by default. Again, a case where we'll need to # store information until we use it, so best to have a way to collect the # attributes that this method sets. # TODO if we are going to set visual properties of the axes with these methods, # then we could do the steps currently in CategoricalPlotter._adjust_cat_axis # TODO another, and distinct idea, is to expose a cut= param here _check_argument("axis", ["x", "y"], axis) # Categorical plots can be "univariate" in which case they get an anonymous # category label on the opposite axis. if axis not in self.variables: self.variables[axis] = None self.var_types[axis] = "categorical" self.plot_data[axis] = "" # If the "categorical" variable has a numeric type, sort the rows so that # the default result from categorical_order has those values sorted after # they have been coerced to strings. The reason for this is so that later # we can get facet-wise orders that are correct. # XXX Should this also sort datetimes? # It feels more consistent, but technically will be a default change # If so, should also change categorical_order to behave that way if self.var_types[axis] == "numeric": self.plot_data = self.plot_data.sort_values(axis, kind="mergesort") # Now get a reference to the categorical data vector and remove na values cat_data = self.plot_data[axis].dropna() # Get the initial categorical order, which we do before string # conversion to respect the original types of the order list. 
class VariableType(UserString):
    """
    Prevent comparisons elsewhere in the library from using the wrong name.

    Errors are simple assertions because users should not be able to trigger
    them. If that changes, they should be more verbose.

    """
    # TODO we can replace this with typing.Literal on Python 3.8+
    allowed = "numeric", "datetime", "categorical"

    def __init__(self, data):
        assert data in self.allowed, data
        super().__init__(data)

    def __eq__(self, other):
        assert other in self.allowed, other
        return self.data == other


def variable_type(vector, boolean_type="numeric"):
    """
    Determine whether a vector contains numeric, categorical, or datetime data.

    This function differs from the pandas typing API in two ways:

    - Python sequences or object-typed PyData objects are considered numeric if
      all of their entries are numeric.
    - String or mixed-type data are considered categorical even if not
      explicitly represented as a :class:`pandas.api.types.CategoricalDtype`.

    Parameters
    ----------
    vector : :func:`pandas.Series`, :func:`numpy.ndarray`, or Python sequence
        Input data to test.
    boolean_type : 'numeric' or 'categorical'
        Type to use for vectors containing only 0s and 1s (and NAs).

    Returns
    -------
    var_type : 'numeric', 'categorical', or 'datetime'
        Name identifying the type of data in the vector.
    """
    vector = pd.Series(vector)

    # If a categorical dtype is set, infer categorical
    if isinstance(vector.dtype, pd.CategoricalDtype):
        return VariableType("categorical")

    # Special-case all-na data, which is always "numeric"
    if pd.isna(vector).all():
        return VariableType("numeric")

    # At this point, drop nans to simplify further type inference
    vector = vector.dropna()

    # Special-case binary/boolean data, allow caller to determine
    # This triggers a numpy warning when vector has strings/objects
    # https://github.com/numpy/numpy/issues/6784
    # Because we reduce with .all(), we are agnostic about whether the
    # comparison returns a scalar or vector, so we will ignore the warning.
    # It triggers a separate DeprecationWarning when the vector has datetimes:
    # https://github.com/numpy/numpy/issues/13548
    # This is considered a bug by numpy and will likely go away.
    with warnings.catch_warnings():
        # Install one filter per category: the documented contract of
        # ``warnings.simplefilter`` is a single Warning subclass, so do not
        # rely on a tuple happening to work through ``issubclass``.
        for category in (FutureWarning, DeprecationWarning):
            warnings.simplefilter(action="ignore", category=category)
        try:
            if np.isin(vector, [0, 1]).all():
                return VariableType(boolean_type)
        except TypeError:
            # .isin comparison is not guaranteed to be possible under NumPy
            # casting rules, depending on the (unknown) dtype of 'vector'
            pass

    # Defer to positive pandas tests
    if pd.api.types.is_numeric_dtype(vector):
        return VariableType("numeric")

    if pd.api.types.is_datetime64_dtype(vector):
        return VariableType("datetime")

    # --- If we get to here, we need to check the entries

    # Check for a collection where everything is a number
    if all(isinstance(entry, Number) for entry in vector):
        return VariableType("numeric")

    # Check for a collection where everything is a datetime
    if all(isinstance(entry, (datetime, np.datetime64)) for entry in vector):
        return VariableType("datetime")

    # Otherwise, our final fallback is to consider things categorical
    return VariableType("categorical")
def unique_dashes(n):
    """Build an arbitrarily long list of unique dash styles for lines.

    Parameters
    ----------
    n : int
        Number of unique dash specs to generate.

    Returns
    -------
    dashes : list of strings or tuples
        Valid arguments for the ``dashes`` parameter on
        :class:`matplotlib.lines.Line2D`. The first spec is a solid
        line (``""``), the remainder are sequences of long and short
        dashes.

    """
    # Hand-picked specs that are easy to tell apart come first
    styles = [
        "",
        (4, 1.5),
        (1, 1),
        (3, 1.25, 1.5, 1.25),
        (5, 1, 1, 1),
    ]

    # Then keep generating patterns with one more dash per round
    n_segments = 3
    while len(styles) < n:

        coarse = list(
            itertools.combinations_with_replacement([3, 1.25], n_segments)
        )
        fine = list(
            itertools.combinations_with_replacement([4, 1], n_segments)
        )

        # Skip the uniform first/last combinations and alternate between the
        # two families, walking the first one backwards
        for pair in zip(reversed(coarse[1:-1]), fine[1:-1]):
            for segments in pair:
                # Every dash is followed by a gap as long as the shortest dash
                gap = min(segments)
                flat = []
                for seg in segments:
                    flat += [seg, gap]
                styles.append(tuple(flat))

        n_segments += 1

    return styles[:n]


def unique_markers(n):
    """Build an arbitrarily long list of unique marker styles for points.

    Parameters
    ----------
    n : int
        Number of unique marker specs to generate.

    Returns
    -------
    markers : list of string or tuples
        Values for defining :class:`matplotlib.markers.MarkerStyle` objects.
        All markers will be filled.

    """
    # Hand-picked specs that are easy to tell apart come first
    styles = [
        "o",
        "X",
        (4, 0, 45),
        "P",
        (4, 0, 0),
        (4, 1, 0),
        "^",
        (4, 1, 45),
        "v",
    ]

    # Then add regular polygons / stars with a growing number of sides
    for sides in itertools.count(5):
        if len(styles) >= n:
            break
        rotation = 360 / (sides + 1) / 2
        styles += [
            (sides + 1, 1, rotation),
            (sides + 1, 0, rotation),
            (sides, 1, 0),
            (sides, 0, 0),
        ]

    # Raw specs are returned; conversion to MarkerStyle objects is disabled
    return styles[:n]
""" if order is None: if hasattr(vector, "categories"): order = vector.categories else: try: order = vector.cat.categories except (TypeError, AttributeError): order = pd.Series(vector).unique() if variable_type(vector) == "numeric": order = np.sort(order) order = filter(pd.notnull, order) return list(order) ================================================ FILE: seaborn/_compat.py ================================================ from __future__ import annotations from typing import Literal import numpy as np import pandas as pd import matplotlib as mpl from matplotlib.figure import Figure from seaborn.utils import _version_predates def norm_from_scale(scale, norm): """Produce a Normalize object given a Scale and min/max domain limits.""" # This is an internal maplotlib function that simplifies things to access # It is likely to become part of the matplotlib API at some point: # https://github.com/matplotlib/matplotlib/issues/20329 if isinstance(norm, mpl.colors.Normalize): return norm if scale is None: return None if norm is None: vmin = vmax = None else: vmin, vmax = norm # TODO more helpful error if this fails? 
def set_layout_engine(
    fig: Figure,
    engine: Literal["constrained", "compressed", "tight", "none"],
) -> None:
    """
    Handle changes to auto layout engine interface in 3.6.

    Figures exposing ``set_layout_engine`` receive ``engine`` directly.
    Otherwise the request is translated to the older tight / constrained
    layout toggles; engine names without a legacy counterpart are ignored.
    """
    if hasattr(fig, "set_layout_engine"):
        fig.set_layout_engine(engine)
        return

    # From here on: _version_predates(mpl, 3.6)
    if engine == "tight":
        fig.set_tight_layout(True)  # type: ignore  # predates typing
        return

    if engine == "constrained":
        fig.set_constrained_layout(True)  # type: ignore
        return

    if engine == "none":
        fig.set_tight_layout(False)  # type: ignore
        fig.set_constrained_layout(False)  # type: ignore
class PlotData:
    """
    Data table with plot variable schema and mapping to original names.

    Contains logic for parsing variable specification arguments and updating
    the table with layer-specific data and/or mappings.

    Parameters
    ----------
    data
        Input data where variable names map to vector values.
    variables
        Keys are names of plot variables (x, y, ...) each value is one of:

        - name of a column (or index level, or dictionary entry) in `data`
        - vector in any format that can construct a :class:`pandas.DataFrame`

    Attributes
    ----------
    frame
        Data table with column names having defined plot variables.
    names
        Dictionary mapping plot variable names to names in source data
        structure(s).
    ids
        Dictionary mapping plot variable names to unique data source
        identifiers.

    """
    frame: DataFrame
    frames: dict[tuple, DataFrame]
    names: dict[str, str | None]
    ids: dict[str, str | int]
    source_data: DataSource
    source_vars: dict[str, VariableSpec]

    def __init__(
        self,
        data: DataSource,
        variables: dict[str, VariableSpec],
    ):

        data = handle_data_source(data)
        self.frame, self.names, self.ids = self._assign_variables(data, variables)

        # A dictionary of frames exists to support the Plot.pair operation,
        # post scaling, where each x/y variable needs its own frame. This is
        # clumsy: client code has several awkward "if frame / elif frames"
        # constructions. A cleaner abstraction would be welcome.
        self.frames = {}

        self.source_data = data
        self.source_vars = variables

    def __contains__(self, key: str) -> bool:
        """Boolean check on whether a variable is defined in this dataset."""
        if self.frame is not None:
            return key in self.frame
        return any(key in df for df in self.frames.values())

    def join(
        self,
        data: DataSource,
        variables: dict[str, VariableSpec] | None,
    ) -> PlotData:
        """Add, replace, or drop variables and return as a new dataset."""
        # Fall back to the upstream data source when none is given
        if data is None:
            data = self.source_data

        # TODO allow `data` to be a function (that is called on the source data?)

        if not variables:
            variables = self.source_vars

        # A variable explicitly set to None is removed from this layer
        dropped = [name for name, spec in variables.items() if spec is None]

        # Parse only the information passed to this call
        new = PlotData(data, variables)

        # -- Layer the new information over what was inherited

        overridden = [
            col for col in self.frame if col in new.frame or col in dropped
        ]
        inherited = self.frame.drop(columns=overridden)

        # The columns are distinct, so this is conceptually a merge/join;
        # concat is used because simple testing suggested it is marginally
        # faster.
        combined = pd.concat([inherited, new.frame], axis=1, sort=False)

        merged_names = {
            name: src for name, src in self.names.items() if name not in dropped
        }
        merged_names.update(new.names)

        merged_ids = {
            name: ident for name, ident in self.ids.items() if name not in dropped
        }
        merged_ids.update(new.ids)

        new.frame = combined
        new.names = merged_names
        new.ids = merged_ids

        # Chained operations always inherit from the original object
        new.source_data = self.source_data
        new.source_vars = self.source_vars

        return new

    def _assign_variables(
        self,
        data: DataFrame | Mapping | None,
        variables: dict[str, VariableSpec],
    ) -> tuple[DataFrame, dict[str, str | None], dict[str, str | int]]:
        """
        Assign values for plot variables given long-form data and/or vector inputs.

        Parameters
        ----------
        data
            Input data where variable names map to vector values.
        variables
            Keys are names of plot variables (x, y, ...) each value is one of:

            - name of a column (or index level, or dictionary entry) in `data`
            - vector in any format that can construct a :class:`pandas.DataFrame`

        Returns
        -------
        frame
            Table mapping seaborn variables (x, y, color, ...) to data vectors.
        names
            Keys are defined seaborn variables; values are names inferred from
            the inputs (or None when no name can be determined).
        ids
            Like the `names` dict, but `None` values are replaced by the `id()`
            of the data object that defined the variable.

        Raises
        ------
        TypeError
            When data source is not a DataFrame or Mapping.
        ValueError
            When variables are strings that don't appear in `data`, or when they are
            non-indexed vector datatypes that have a different length from `data`.

        """
        source_data: Mapping | DataFrame
        frame: DataFrame
        names: dict[str, str | None]
        ids: dict[str, str | int]

        plot_data = {}
        names = {}
        ids = {}

        given_data = data is not None

        # Data is optional (every variable may be a vector), but downstream
        # code is simpler with an object that always supports lookups
        source_data = {} if data is None else data

        # Index levels of a DataFrame can be referenced by name as well
        index = {}
        if isinstance(source_data, pd.DataFrame):
            index = source_data.index.to_frame().to_dict("series")

        for var, spec in variables.items():

            # Variables with no specification are simply ignored
            if spec is None:
                continue

            # First try to use the spec as a key into the data collection.
            # Any hashable may reference the main data object, but only
            # strings may reference index levels (anything else would be too
            # ambiguous).
            # TODO this will be rendered unnecessary by the following pandas fix:
            # https://github.com/pandas-dev/pandas/pull/41283
            try:
                hash(spec)
            except TypeError:
                hashable = False
            else:
                hashable = True

            in_source = hashable and spec in source_data
            in_index = isinstance(spec, str) and spec in index

            if in_source or in_index:
                spec = cast(ColumnName, spec)
                if in_source:
                    plot_data[var] = source_data[spec]
                else:
                    plot_data[var] = index[spec]
                names[var] = ids[var] = str(spec)
                continue

            if isinstance(spec, str):
                # Looked like a column name, but the lookup failed
                err = f"Could not interpret value `{spec}` for `{var}`. "
                if given_data:
                    err += "An entry with this name does not appear in `data`."
                else:
                    err += "Value is a string, but `data` was not passed."
                raise ValueError(err)

            # Otherwise the value is assumed to represent data itself

            # Empty data structures are ignored
            if isinstance(spec, Sized) and len(spec) == 0:
                continue

            # A vector without an index must match the length of the table
            needs_length_check = (
                isinstance(data, pd.DataFrame) and not isinstance(spec, pd.Series)
            )
            if needs_length_check:
                if isinstance(spec, Sized) and len(data) != len(spec):
                    val_cls = spec.__class__.__name__
                    err = (
                        f"Length of {val_cls} vectors must match length of `data`"
                        f" when both are used, but `data` has length {len(data)}"
                        f" and the vector passed to `{var}` has length {len(spec)}."
                    )
                    raise ValueError(err)

            plot_data[var] = spec

            # Recover the original name from pandas-like metadata if possible
            if hasattr(spec, "name"):
                names[var] = ids[var] = str(spec.name)  # type: ignore  # mypy/1424
            else:
                names[var] = None
                ids[var] = id(spec)

        # Build a tidy plot DataFrame; pandas converts many input types
        # automatically and aligns pandas objects on their index
        # TODO Note: this fails when variable specs *only* have scalars!
        frame = pd.DataFrame(plot_data)

        return frame, names, ids
class GroupBy:
    """
    Interface for Pandas GroupBy operations allowing specified group order.

    Writing our own class to do this has a few advantages:
    - It constrains the interface between Plot and Stat/Move objects
    - It allows control over the row order of the GroupBy result, which is
      important when using in the context of some Move operations (dodge, stack, ...)
    - It simplifies some complexities regarding the return type and Index contents
      one encounters with Pandas, especially for DataFrame -> DataFrame applies
    - It increases future flexibility regarding alternate DataFrame libraries

    """
    def __init__(self, order: list[str] | dict[str, list | None]):
        """
        Initialize the GroupBy from grouping variables and optional level orders.

        Parameters
        ----------
        order
            List of variable names or dict mapping names to desired level orders.
            Level order values can be None to use default ordering rules. The
            variables can include names that are not expected to appear in the
            data; these will be dropped before the groups are defined.

        """
        if not order:
            raise ValueError("GroupBy requires at least one grouping variable")

        # A bare list of names means "default ordering for every variable"
        self.order = dict.fromkeys(order) if isinstance(order, list) else order

    def _get_groups(
        self, data: DataFrame
    ) -> tuple[str | list[str], Index | MultiIndex]:
        """Return index with Cartesian product of ordered grouping variable levels."""
        levels = {}
        for var, level_order in self.order.items():
            if var not in data:
                continue
            if level_order is None:
                level_order = categorical_order(data[var])
            levels[var] = level_order

        grouper: str | list[str]
        groups: Index | MultiIndex

        n_vars = len(levels)
        if n_vars == 0:
            grouper = []
            groups = pd.Index([])
        elif n_vars == 1:
            (grouper,) = levels
            groups = pd.Index(levels[grouper], name=grouper)
        else:
            grouper = list(levels)
            groups = pd.MultiIndex.from_product(levels.values(), names=grouper)
        return grouper, groups

    def _reorder_columns(self, res, data):
        """Reorder result columns to match original order with new columns appended."""
        kept = [col for col in data if col in res]
        added = [col for col in res if col not in data]
        return res.reindex(columns=pd.Index(kept + added))

    def agg(self, data: DataFrame, *args, **kwargs) -> DataFrame:
        """
        Reduce each group to a single row in the output.

        The output will have a row for each unique combination of the grouping
        variable levels with null values for the aggregated variable(s) where
        those combinations do not appear in the dataset.

        """
        grouper, groups = self._get_groups(data)

        if not grouper:
            # We will need to see whether there are valid usecases that end up here
            raise ValueError("No grouping variables are present in dataframe")

        grouped = data.groupby(grouper, sort=False, observed=False)
        reduced = grouped.agg(*args, **kwargs)
        # Reindexing both orders the rows and adds null rows for absent groups
        full = reduced.reindex(groups).reset_index()
        return self._reorder_columns(full, data)

    def apply(
        self, data: DataFrame, func: Callable[..., DataFrame],
        *args, **kwargs,
    ) -> DataFrame:
        """Apply a DataFrame -> DataFrame mapping to each group."""
        grouper, groups = self._get_groups(data)

        if not grouper:
            return self._reorder_columns(func(data, *args, **kwargs), data)

        parts = {
            key: func(part_df, *args, **kwargs)
            for key, part_df in data.groupby(grouper, sort=False, observed=False)
        }

        # A list grouper implies a MultiIndex, so each key is iterable
        multiple = isinstance(grouper, list)

        stack = []
        for key in groups:
            if key not in parts:
                continue
            if multiple:
                group_ids = dict(zip(grouper, cast(Iterable, key)))
            else:
                group_ids = {grouper: key}
            stack.append(parts[key].assign(**group_ids))

        res = pd.concat(stack, ignore_index=True)
        return self._reorder_columns(res, data)
@dataclass
class Jitter(Move):
    """
    Random displacement along one or both axes to reduce overplotting.

    Parameters
    ----------
    width : float
        Magnitude of jitter, relative to mark width, along the orientation axis.
        If not provided, the default value will be 0 when `x` or `y` are set,
        otherwise there will be a small amount of jitter applied by default.
    x : float
        Magnitude of jitter, in data units, along the x axis.
    y : float
        Magnitude of jitter, in data units, along the y axis.
    seed : int or None
        Seed passed to :func:`numpy.random.default_rng` when the move is
        applied.

    Examples
    --------
    .. include:: ../docstrings/objects.Jitter.rst

    """
    width: float | Default = default
    x: float = 0
    y: float = 0
    seed: int | None = None

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
        """Return a copy of `data` with uniform noise added to the positions."""
        data = data.copy()
        rng = np.random.default_rng(self.seed)

        def jitter(data, col, scale):
            # Uniform noise in [-0.5, 0.5) stretched by the requested scale
            noise = rng.uniform(-.5, +.5, len(data))
            offsets = noise * scale
            return data[col] + offsets

        if self.width is default:
            width = 0.0 if self.x or self.y else 0.2
        else:
            width = cast(float, self.width)

        # Test the *resolved* width, not the raw field: the field may still
        # hold the `default` sentinel even though the effective width is 0,
        # in which case no orient-axis jitter (and no random draw) should
        # happen, exactly as with an explicit ``width=0``.
        if width:
            data[orient] = jitter(data, orient, width * data["width"])
        if self.x:
            data["x"] = jitter(data, "x", self.x)
        if self.y:
            data["y"] = jitter(data, "y", self.y)

        return data
@dataclass
class Stack(Move):
    """
    Displacement of overlapping bar or area marks along the value axis.

    Examples
    --------
    .. include:: ../docstrings/objects.Stack.rst

    """
    # TODO center? (or should this be a different move, eg. Stream())

    def _stack(self, df, orient):
        """Stack the marks of one group along the axis opposite to `orient`."""
        # TODO should stack do something with ymin/ymax style marks?
        # Should there be an upstream conversion to baseline/height parameterization?

        if df["baseline"].nunique() > 1:
            raise RuntimeError(
                "Stack move cannot be used when baselines are already heterogeneous"
            )

        value_var = {"x": "y", "y": "x"}[orient]

        # Each mark ends at the running total of the lengths before and including
        # it, and starts where the previous mark ended.
        lengths = df[value_var] - df["baseline"]
        tops = lengths.dropna().cumsum()
        shifts = tops.shift(1).fillna(0)

        df[value_var] = tops
        df["baseline"] = df["baseline"] + shifts

        return df

    def __call__(
        self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
    ) -> DataFrame:
        """Apply the stacking separately within each facet / orient position."""
        # TODO where to ensure that other semantic variables are sorted properly?
        # TODO why are we not using the passed in groupby here?
        position_vars = ["col", "row", orient]
        stacker = GroupBy(position_vars)
        return stacker.apply(data, self._stack, orient)
include:: ../docstrings/objects.Norm.rst """ func: Union[Callable, str] = "max" where: Optional[str] = None by: Optional[list[str]] = None percent: bool = False group_by_orient: ClassVar[bool] = False def _norm(self, df, var): if self.where is None: denom_data = df[var] else: denom_data = df.query(self.where)[var] df[var] = df[var] / denom_data.agg(self.func) if self.percent: df[var] = df[var] * 100 return df def __call__( self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale], ) -> DataFrame: other = {"x": "y", "y": "x"}[orient] return groupby.apply(data, self._norm, other) # TODO # @dataclass # class Ridge(Move): # ... ================================================ FILE: seaborn/_core/plot.py ================================================ """The classes for specifying and compiling a declarative visualization.""" from __future__ import annotations import io import os import re import inspect import itertools import textwrap from contextlib import contextmanager from collections import abc from collections.abc import Callable, Generator, Mapping from typing import Any, List, Literal, Optional, cast from xml.etree import ElementTree from cycler import cycler import pandas as pd from pandas import DataFrame, Series, Index import matplotlib as mpl from matplotlib.axes import Axes from matplotlib.artist import Artist from matplotlib.figure import Figure import numpy as np from PIL import Image from seaborn._marks.base import Mark from seaborn._stats.base import Stat from seaborn._core.data import PlotData from seaborn._core.moves import Move from seaborn._core.scales import Scale from seaborn._core.subplots import Subplots from seaborn._core.groupby import GroupBy from seaborn._core.properties import PROPERTIES, Property from seaborn._core.typing import ( DataSource, VariableSpec, VariableSpecList, OrderSpec, Default, ) from seaborn._core.exceptions import PlotSpecError from seaborn._core.rules import categorical_order from seaborn._compat 
@contextmanager
def theme_context(params: dict[str, Any]) -> Generator:
    """Temporarily modify specific matplotlib rcParams."""
    saved_rc = {name: mpl.rcParams[name] for name in params}

    short_codes = "bgrmyck"
    themed_colors = [*color_palette("deep6"), (.15, .15, .15)]
    saved_colors = [mpl.colors.colorConverter.colors[c] for c in short_codes]
    # TODO how to allow this to reflect the color cycle when relevant?

    def _assign_short_codes(colors):
        # Point each single-letter color code at the corresponding color
        for code, color in zip(short_codes, colors):
            mpl.colors.colorConverter.colors[code] = color

    try:
        mpl.rcParams.update(params)
        _assign_short_codes(themed_colors)
        yield
    finally:
        # Put back whatever was in place before entering, even on error
        mpl.rcParams.update(saved_rc)
        _assign_short_codes(saved_colors)
""" sig = inspect.signature(cls) params = [ inspect.Parameter("args", inspect.Parameter.VAR_POSITIONAL), inspect.Parameter("data", inspect.Parameter.KEYWORD_ONLY, default=None) ] params.extend([ inspect.Parameter(name, inspect.Parameter.KEYWORD_ONLY, default=None) for name in PROPERTIES ]) new_sig = sig.replace(parameters=params) cls.__signature__ = new_sig known_properties = textwrap.fill( ", ".join([f"|{p}|" for p in PROPERTIES]), width=78, subsequent_indent=" " * 8, ) if cls.__doc__ is not None: # support python -OO mode cls.__doc__ = cls.__doc__.format(known_properties=known_properties) return cls # ---- Plot configuration ---------------------------------------------------------- # class ThemeConfig(mpl.RcParams): """ Configuration object for the Plot.theme, using matplotlib rc parameters. """ THEME_GROUPS = [ "axes", "figure", "font", "grid", "hatch", "legend", "lines", "mathtext", "markers", "patch", "savefig", "scatter", "xaxis", "xtick", "yaxis", "ytick", ] def __init__(self): super().__init__() self.reset() @property def _default(self) -> dict[str, Any]: return { **self._filter_params(mpl.rcParamsDefault), **axes_style("darkgrid"), **plotting_context("notebook"), "axes.prop_cycle": cycler("color", color_palette("deep")), } def reset(self) -> None: """Update the theme dictionary with seaborn's default values.""" self.update(self._default) def update(self, other: dict[str, Any] | None = None, /, **kwds): """Update the theme with a dictionary or keyword arguments of rc parameters.""" if other is not None: theme = self._filter_params(other) else: theme = {} theme.update(kwds) super().update(theme) def _filter_params(self, params: dict[str, Any]) -> dict[str, Any]: """Restruct to thematic rc params.""" return { k: v for k, v in params.items() if any(k.startswith(p) for p in self.THEME_GROUPS) } def _html_table(self, params: dict[str, Any]) -> list[str]: lines = [""] for k, v in params.items(): row = f"" lines.append(row) lines.append("
{k}:{v!r}
") return lines def _repr_html_(self) -> str: repr = [ "
", "
", *self._html_table(self), "
", "
", ] return "\n".join(repr) class DisplayConfig(TypedDict): """Configuration for IPython's rich display hooks.""" format: Literal["png", "svg"] scaling: float hidpi: bool class PlotConfig: """Configuration for default behavior / appearance of class:`Plot` instances.""" def __init__(self): self._theme = ThemeConfig() self._display = {"format": "png", "scaling": .85, "hidpi": True} @property def theme(self) -> dict[str, Any]: """ Dictionary of base theme parameters for :class:`Plot`. Keys and values correspond to matplotlib rc params, as documented here: https://matplotlib.org/stable/tutorials/introductory/customizing.html """ return self._theme @property def display(self) -> DisplayConfig: """ Dictionary of parameters for rich display in Jupyter notebook. Valid parameters: - format ("png" or "svg"): Image format to produce - scaling (float): Relative scaling of embedded image - hidpi (bool): When True, double the DPI while preserving the size """ return self._display # ---- The main interface for declarative plotting --------------------------------- # @build_plot_signature class Plot: """ An interface for declaratively specifying statistical graphics. Plots are constructed by initializing this class and adding one or more layers, comprising a `Mark` and optional `Stat` or `Move`. Additionally, faceting variables or variable pairings may be defined to divide the space into multiple subplots. The mappings from data values to visual properties can be parametrized using scales, although the plot will try to infer good defaults when scales are not explicitly defined. The constructor accepts a data source (a :class:`pandas.DataFrame` or dictionary with columnar values) and variable assignments. Variables can be passed as keys to the data source or directly as data vectors. If multiple data-containing objects are provided, they will be index-aligned. 
The data source and variables defined in the constructor will be used for all layers in the plot, unless overridden or disabled when adding a layer. The following variables can be defined in the constructor: {known_properties} The `data`, `x`, and `y` variables can be passed as positional arguments or using keywords. Whether the first positional argument is interpreted as a data source or `x` variable depends on its type. The methods of this class return a copy of the instance; use chaining to build up a plot through multiple calls. Methods can be called in any order. Most methods only add information to the plot spec; no actual processing happens until the plot is shown or saved. It is also possible to compile the plot without rendering it to access the lower-level representation. """ config = PlotConfig() _data: PlotData _layers: list[Layer] _scales: dict[str, Scale] _shares: dict[str, bool | str] _limits: dict[str, tuple[Any, Any]] _labels: dict[str, str | Callable[[str], str]] _theme: dict[str, Any] _facet_spec: FacetSpec _pair_spec: PairSpec _figure_spec: dict[str, Any] _subplot_spec: dict[str, Any] _layout_spec: dict[str, Any] def __init__( self, *args: DataSource | VariableSpec, data: DataSource = None, **variables: VariableSpec, ): if args: data, variables = self._resolve_positionals(args, data, variables) unknown = [x for x in variables if x not in PROPERTIES] if unknown: err = f"Plot() got unexpected keyword argument(s): {', '.join(unknown)}" raise TypeError(err) self._data = PlotData(data, variables) self._layers = [] self._scales = {} self._shares = {} self._limits = {} self._labels = {} self._theme = {} self._facet_spec = {} self._pair_spec = {} self._figure_spec = {} self._subplot_spec = {} self._layout_spec = {} self._target = None def _resolve_positionals( self, args: tuple[DataSource | VariableSpec, ...], data: DataSource, variables: dict[str, VariableSpec], ) -> tuple[DataSource, dict[str, VariableSpec]]: """Handle positional arguments, which may 
contain data / x / y.""" if len(args) > 3: err = "Plot() accepts no more than 3 positional arguments (data, x, y)." raise TypeError(err) if ( isinstance(args[0], (abc.Mapping, pd.DataFrame)) or hasattr(args[0], "to_pandas") ): if data is not None: raise TypeError("`data` given by both name and position.") data, args = args[0], args[1:] if len(args) == 2: x, y = args elif len(args) == 1: x, y = *args, None else: x = y = None for name, var in zip("yx", (y, x)): if var is not None: if name in variables: raise TypeError(f"`{name}` given by both name and position.") # Keep coordinates at the front of the variables dict # Cast type because we know this isn't a DataSource at this point variables = {name: cast(VariableSpec, var), **variables} return data, variables def __add__(self, other): if isinstance(other, Mark) or isinstance(other, Stat): raise TypeError("Sorry, this isn't ggplot! Perhaps try Plot.add?") other_type = other.__class__.__name__ raise TypeError(f"Unsupported operand type(s) for +: 'Plot' and '{other_type}") def _repr_png_(self) -> tuple[bytes, dict[str, float]] | None: if Plot.config.display["format"] != "png": return None return self.plot()._repr_png_() def _repr_svg_(self) -> str | None: if Plot.config.display["format"] != "svg": return None return self.plot()._repr_svg_() def _clone(self) -> Plot: """Generate a new object with the same information as the current spec.""" new = Plot() # TODO any way to enforce that data does not get mutated? 
new._data = self._data new._layers.extend(self._layers) new._scales.update(self._scales) new._shares.update(self._shares) new._limits.update(self._limits) new._labels.update(self._labels) new._theme.update(self._theme) new._facet_spec.update(self._facet_spec) new._pair_spec.update(self._pair_spec) new._figure_spec.update(self._figure_spec) new._subplot_spec.update(self._subplot_spec) new._layout_spec.update(self._layout_spec) new._target = self._target return new def _theme_with_defaults(self) -> dict[str, Any]: theme = self.config.theme.copy() theme.update(self._theme) return theme @property def _variables(self) -> list[str]: variables = ( list(self._data.frame) + list(self._pair_spec.get("variables", [])) + list(self._facet_spec.get("variables", [])) ) for layer in self._layers: variables.extend(v for v in layer["vars"] if v not in variables) # Coerce to str in return to appease mypy; we know these will only # ever be strings but I don't think we can type a DataFrame that way yet return [str(v) for v in variables] def on(self, target: Axes | SubFigure | Figure) -> Plot: """ Provide existing Matplotlib figure or axes for drawing the plot. When using this method, you will also need to explicitly call a method that triggers compilation, such as :meth:`Plot.show` or :meth:`Plot.save`. If you want to postprocess using matplotlib, you'd need to call :meth:`Plot.plot` first to compile the plot without rendering it. Parameters ---------- target : Axes, SubFigure, or Figure Matplotlib object to use. Passing :class:`matplotlib.axes.Axes` will add artists without otherwise modifying the figure. Otherwise, subplots will be created within the space of the given :class:`matplotlib.figure.Figure` or :class:`matplotlib.figure.SubFigure`. Examples -------- .. 
def add(
    self,
    mark: Mark,
    *transforms: Stat | Move,
    orient: str | None = None,
    legend: bool = True,
    label: str | None = None,
    data: DataSource = None,
    **variables: VariableSpec,
) -> Plot:
    """
    Specify a layer of the visualization in terms of mark and data transform(s).

    This is the main method for specifying how the data should be visualized.
    It can be called multiple times with different arguments to define
    a plot with multiple layers.

    Parameters
    ----------
    mark : :class:`Mark`
        The visual representation of the data to use in this layer.
    transforms : :class:`Stat` or :class:`Move`
        Objects representing transforms to be applied before plotting the data.
        Currently, at most one :class:`Stat` can be used, and it
        must be passed first. This constraint will be relaxed in the future.
    orient : "x", "y", "v", or "h"
        The orientation of the mark, which also affects how transforms are computed.
        Typically corresponds to the axis that defines groups for aggregation.
        The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y",
        but may be more intuitive with some marks. When not provided, an
        orientation will be inferred from characteristics of the data and scales.
    legend : bool
        Option to suppress the mark/mappings for this layer from the legend.
    label : str
        A label to use for the layer in the legend, independent of any mappings.
    data : DataFrame or dict
        Data source to override the global source provided in the constructor.
    variables : data vectors or identifiers
        Additional layer-specific variables, including variables that will be
        passed directly to the transforms without scaling.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.add.rst

    """
    if not isinstance(mark, Mark):
        raise TypeError(f"mark must be a Mark instance, not {type(mark)!r}.")

    # TODO This API for transforms was a late decision, and previously Plot.add
    # accepted 0 or 1 Stat instances and 0, 1, or a list of Move instances.
    # It will take some work to refactor the internals so that Stat and Move are
    # treated identically, and until then we'll need to "unpack" the transforms
    # here and enforce limitations on the order / types.

    stat: Optional[Stat] = None
    move: Optional[List[Move]] = None

    if transforms:
        first = transforms[0]
        if isinstance(first, Stat):
            stat = first
            candidates = transforms[1:]
        else:
            candidates = transforms

        # Anything after an optional leading Stat has to be a Move
        move = [t for t in candidates if isinstance(t, Move)]
        if len(move) != len(candidates):
            given = ", ".join(str(type(t).__name__) for t in transforms)
            msg = " ".join([
                "Transforms must have at most one Stat type (in the first position),",
                "and all others must be a Move type. Given transform type(s):",
                given + "."
            ])
            raise TypeError(msg)

    new = self._clone()
    new._layers.append({
        "mark": mark,
        "stat": stat,
        "move": move,
        # TODO it doesn't work to supply scalars to variables, but it should
        "vars": variables,
        "source": data,
        "legend": legend,
        "label": label,
        "orient": {"v": "x", "h": "y"}.get(orient, orient),  # type: ignore
    })

    return new
def facet(
    self,
    col: VariableSpec = None,
    row: VariableSpec = None,
    order: OrderSpec | dict[str, OrderSpec] = None,
    wrap: int | None = None,
) -> Plot:
    """
    Produce subplots with conditional subsets of the data.

    Parameters
    ----------
    col, row : data vectors or identifiers
        Variables used to define subsets along the columns and/or rows of the grid.
        Can be references to the global data source passed in the constructor.
    order : list of strings, or dict with dimensional keys
        Define the order of the faceting variables.
    wrap : int
        When using only `col` or `row`, wrap subplots across a two-dimensional
        grid with this many subplots on the faceting dimension.

    Raises
    ------
    RuntimeError
        When `order` is a list but both `col` and `row` are used, so that the
        dimension it applies to cannot be determined.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.facet.rst

    """
    variables: dict[str, VariableSpec] = {}
    if col is not None:
        variables["col"] = col
    if row is not None:
        variables["row"] = row

    structure = {}
    if isinstance(order, dict):
        for dim in ["col", "row"]:
            dim_order = order.get(dim)
            if dim_order is not None:
                structure[dim] = list(dim_order)
    elif order is not None:
        if col is not None and row is not None:
            # The two halves were previously adjacent literals with no
            # separator, which produced "as a listis ambiguous".
            err = (
                "When faceting on both col= and row=, passing `order` as a list "
                "is ambiguous. Use a dict with 'col' and/or 'row' keys instead."
            )
            raise RuntimeError(err)
        elif col is not None:
            structure["col"] = list(order)
        elif row is not None:
            structure["row"] = list(order)

    spec: FacetSpec = {
        "variables": variables,
        "structure": structure,
        "wrap": wrap,
    }

    new = self._clone()
    new._facet_spec.update(spec)

    return new
def limit(self, **limits: tuple[Any, Any]) -> Plot:
    """
    Control the range of visible data.

    Keywords correspond to variables defined in the plot, and values are a
    `(min, max)` tuple (where either can be `None` to leave unset).

    Limits apply only to the axis; data outside the visible range are
    still used for any stat transforms and added to the plot.

    Behavior for non-coordinate variables is currently undefined.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.limit.rst

    """
    new = self._clone()
    # Later calls override earlier limits for the same variable
    for var, bounds in limits.items():
        new._limits[var] = bounds
    return new
def layout(
    self,
    *,
    size: tuple[float, float] | Default = default,
    engine: str | None | Default = default,
    extent: tuple[float, float, float, float] | Default = default,
) -> Plot:
    """
    Control the figure size and layout.

    .. note::

        Default figure sizes and the API for specifying the figure size are subject
        to change in future "experimental" releases of the objects API. The default
        layout engine may also change.

    Parameters
    ----------
    size : (width, height)
        Size of the resulting figure, in inches. Size is inclusive of legend when
        using pyplot, but not otherwise.
    engine : {"tight", "constrained", "none"}
        Name of method for automatically adjusting the layout to remove overlap.
        The default depends on whether :meth:`Plot.on` is used.
    extent : (left, bottom, right, top)
        Boundaries of the plot layout, in fractions of the figure size. Takes
        effect through the layout engine; exact results will vary across engines.
        Note: the extent includes axis decorations when using a layout engine,
        but it is exclusive of them when `engine="none"`.

    Examples
    --------
    .. include:: ../docstrings/objects.Plot.layout.rst

    """
    # TODO add an "auto" mode for figsize that roughly scales with the rcParams
    # figsize (so that works), but expands to prevent subplots from being squished
    # Also should we have height=, aspect=, exclusive with figsize? Or working
    # with figsize when only one is defined?

    new = self._clone()

    # (spec dict, key, value); only explicitly passed values are recorded, so
    # that `None` stays available as a meaningful setting.
    requested = (
        (new._figure_spec, "figsize", size),
        (new._layout_spec, "engine", engine),
        (new._layout_spec, "extent", extent),
    )
    for spec, key, value in requested:
        if value is not default:
            spec[key] = value

    return new
def plot(self, pyplot: bool = False) -> Plotter:
    """
    Compile the plot spec and return the Plotter object.
    """
    # Compilation runs with the plot's theme (defaults plus overrides) active
    theme = self._theme_with_defaults()
    with theme_context(theme):
        plotter = self._plot(pyplot)
    return plotter
def save(self, loc, **kwargs) -> Plotter:  # TODO type args
    """
    Write the compiled figure to `loc` and return this Plotter.

    `loc` may be a path (with `~` expanded) or a buffer. Keyword arguments are
    forwarded to the figure's `savefig`, with `dpi` defaulting to 96.
    """
    if "dpi" not in kwargs:
        kwargs["dpi"] = 96

    try:
        target = os.path.expanduser(loc)
    except TypeError:
        # loc may be a buffer in which case that would not work
        target = loc

    self._figure.savefig(target, **kwargs)
    return self
self._figure.savefig(buffer, dpi=dpi, format="png", bbox_inches="tight") data = buffer.getvalue() w, h = Image.open(buffer).size metadata = {"width": w * scaling, "height": h * scaling} return data, metadata def _repr_svg_(self) -> str | None: if Plot.config.display["format"] != "svg": return None # TODO DPI for rasterized artists? scaling = Plot.config.display["scaling"] buffer = io.StringIO() with theme_context(self._theme): # TODO _theme_with_defaults? self._figure.savefig(buffer, format="svg", bbox_inches="tight") root = ElementTree.fromstring(buffer.getvalue()) w = scaling * float(root.attrib["width"][:-2]) h = scaling * float(root.attrib["height"][:-2]) root.attrib.update(width=f"{w}pt", height=f"{h}pt", viewbox=f"0 0 {w} {h}") ElementTree.ElementTree(root).write(out := io.BytesIO()) return out.getvalue().decode() def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]: common_data = ( p._data .join(None, p._facet_spec.get("variables")) .join(None, p._pair_spec.get("variables")) ) layers: list[Layer] = [] for layer in p._layers: spec = layer.copy() spec["data"] = common_data.join(layer.get("source"), layer.get("vars")) layers.append(spec) return common_data, layers def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str: if re.match(r"[xy]\d+", var): key = var if var in p._labels else var[0] else: key = var label: str if key in p._labels: manual_label = p._labels[key] if callable(manual_label) and auto_label is not None: label = manual_label(auto_label) else: label = cast(str, manual_label) elif auto_label is None: label = "" else: label = auto_label return label def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None: # --- Parsing the faceting/pairing parameterization to specify figure grid subplot_spec = p._subplot_spec.copy() facet_spec = p._facet_spec.copy() pair_spec = p._pair_spec.copy() for axis in "xy": if axis in p._shares: subplot_spec[f"share{axis}"] = p._shares[axis] for dim in ["col", "row"]: 
if dim in common.frame and dim not in facet_spec["structure"]: order = categorical_order(common.frame[dim]) facet_spec["structure"][dim] = order self._subplots = subplots = Subplots(subplot_spec, facet_spec, pair_spec) # --- Figure initialization self._figure = subplots.init_figure( pair_spec, self._pyplot, p._figure_spec, p._target, ) # --- Figure annotation for sub in subplots: ax = sub["ax"] for axis in "xy": axis_key = sub[axis] # ~~ Axis labels # TODO Should we make it possible to use only one x/y label for # all rows/columns in a faceted plot? Maybe using sub{axis}label, # although the alignments of the labels from that method leaves # something to be desired (in terms of how it defines 'centered'). names = [ common.names.get(axis_key), *(layer["data"].names.get(axis_key) for layer in layers) ] auto_label = next((name for name in names if name is not None), None) label = self._resolve_label(p, axis_key, auto_label) ax.set(**{f"{axis}label": label}) # ~~ Decoration visibility # TODO there should be some override (in Plot.layout?) so that # axis / tick labels can be shown on interior shared axes if desired axis_obj = getattr(ax, f"{axis}axis") visible_side = {"x": "bottom", "y": "left"}.get(axis) show_axis_label = ( sub[visible_side] or not p._pair_spec.get("cross", True) or ( axis in p._pair_spec.get("structure", {}) and bool(p._pair_spec.get("wrap")) ) ) axis_obj.get_label().set_visible(show_axis_label) show_tick_labels = ( show_axis_label or subplot_spec.get(f"share{axis}") not in ( True, "all", {"x": "col", "y": "row"}[axis] ) ) for group in ("major", "minor"): side = {"x": "bottom", "y": "left"}[axis] axis_obj.set_tick_params(**{f"label{side}": show_tick_labels}) for t in getattr(axis_obj, f"get_{group}ticklabels")(): t.set_visible(show_tick_labels) # TODO we want right-side titles for row facets in most cases? 
def _compute_stats(self, spec: Plot, layers: list[Layer]) -> None:
    """
    Run each layer's stat and store the result back on the layer's data.

    Layers whose ``"stat"`` is ``None`` are skipped. With paired variables the
    stat runs once per (x, y) pairing and results are stored in
    ``data.frames`` keyed by the pairing; otherwise the single result replaces
    ``data.frame``.
    """
    semantic_vars = [v for v in PROPERTIES if v not in "xy"]
    grouping_vars = semantic_vars + ["col", "row", "group"]

    pair_vars = spec._pair_spec.get("structure", {})

    for layer in layers:

        stat = layer["stat"]
        if stat is None:
            continue
        data, mark = layer["data"], layer["mark"]

        pairings = itertools.product(
            *(pair_vars.get(axis, [axis]) for axis in "xy")
        )

        source_frame = data.frame
        if pair_vars:
            data.frames = {}
            data.frame = data.frame.iloc[:0]  # TODO to simplify typing

        for coord_vars in pairings:

            df = source_frame.copy()
            scales = self._scales.copy()

            # Present the paired variable under its plain axis name and drop
            # the other numbered columns for that axis
            for axis, var in zip("xy", coord_vars):
                if axis == var:
                    continue
                df = df.rename(columns={var: axis})
                numbered = [c for c in df if re.match(rf"{axis}\d+", str(c))]
                df = df.drop(numbered, axis=1)
                scales[axis] = scales[var]

            orient = layer["orient"] or mark._infer_orient(scales)

            grouper = [orient, *grouping_vars] if stat.group_by_orient else grouping_vars
            result = stat(df, GroupBy(grouper), orient, scales)

            if pair_vars:
                data.frames[coord_vars] = result
            else:
                data.frame = result
def _setup_scales(
    self,
    p: Plot,
    common: PlotData,
    layers: list[Layer],
    variables: list[str] | None = None,
) -> None:
    """
    Resolve a scale for each variable and store it in ``self._scales``.

    When ``variables`` is ``None``, the variables are collected from the layer
    data, keeping only those that do not have a scale yet. For coordinate
    variables the method also sets the matplotlib scale on each matching
    subplot axis and overwrites the layer data column with the scaled values.

    Raises ``PlotSpecError`` (chained from the original exception) when scale
    setup or the scaling operation fails.
    """
    if variables is None:
        # Add variables that have data but not a scale, which happens
        # because this method can be called multiple time, to handle
        # variables added during the Stat transform.
        variables = []
        for layer in layers:
            variables.extend(layer["data"].frame.columns)
            for df in layer["data"].frames.values():
                variables.extend(str(v) for v in df if v not in variables)
        variables = [v for v in variables if v not in self._scales]

    for var in variables:

        # Determine whether this is a coordinate variable
        # (i.e., x/y, paired x/y, or derivative such as xmax)
        # "coord" is the axis letter plus optional pair index; "axis" is the letter.
        m = re.match(r"^(?P<coord>(?P<axis>x|y)\d*).*", var)
        if m is None:
            coord = axis = None
        else:
            coord = m["coord"]
            axis = m["axis"]

        # Get keys that handle things like x0, xmax, properly where relevant
        prop_key = var if axis is None else axis
        scale_key = var if coord is None else coord

        if prop_key not in PROPERTIES:
            continue

        # Concatenate layers, using only the relevant coordinate and faceting vars,
        # This is unnecessarily wasteful, as layer data will often be redundant.
        # But figuring out the minimal amount we need is more complicated.
        cols = [var, "col", "row"]
        parts = [common.frame.filter(cols)]
        for layer in layers:
            parts.append(layer["data"].frame.filter(cols))
            for df in layer["data"].frames.values():
                parts.append(df.filter(cols))
        var_df = pd.concat(parts, ignore_index=True)

        prop = PROPERTIES[prop_key]
        scale = self._get_scale(p, scale_key, prop, var_df[var])

        if scale_key not in p._variables:
            # TODO this implies that the variable was added by the stat
            # It allows downstream orientation inference to work properly.
            # But it feels rather hacky, so ideally revisit.
            scale._priority = 0  # type: ignore

        if axis is None:
            # We could think about having a broader concept of (un)shared properties
            # In general, not something you want to do (different scales in facets)
            # But could make sense e.g. with paired plots. Build later.
            share_state = None
            subplots = []
        else:
            share_state = self._subplots.subplot_spec[f"share{axis}"]
            subplots = [view for view in self._subplots if view[axis] == coord]

        if scale is None:
            self._scales[var] = Scale._identity()
        else:
            try:
                self._scales[var] = scale._setup(var_df[var], prop)
            except Exception as err:
                raise PlotSpecError._during("Scale setup", var) from err

        if axis is None or (var != coord and coord in p._variables):
            # Everything below here applies only to coordinate variables
            continue

        # Set up an empty series to receive the transformed values.
        # We need this to handle piecemeal transforms of categories -> floats.
        transformed_data = []
        for layer in layers:
            index = layer["data"].frame.index
            empty_series = pd.Series(dtype=float, index=index, name=var)
            transformed_data.append(empty_series)

        for view in subplots:

            axis_obj = getattr(view["ax"], f"{axis}axis")
            seed_values = self._get_subplot_data(var_df, var, view, share_state)
            view_scale = scale._setup(seed_values, prop, axis=axis_obj)
            view["ax"].set(**{f"{axis}scale": view_scale._matplotlib_scale})

            for layer, new_series in zip(layers, transformed_data):
                layer_df = layer["data"].frame
                if var not in layer_df:
                    continue

                idx = self._get_subplot_index(layer_df, view)
                try:
                    new_series.loc[idx] = view_scale(layer_df.loc[idx, var])
                except Exception as err:
                    spec_error = PlotSpecError._during("Scaling operation", var)
                    raise spec_error from err

        # Now the transformed data series are complete, update the layer data
        for layer, new_series in zip(layers, transformed_data):
            layer_df = layer["data"].frame
            if var in layer_df:
                layer_df[var] = pd.to_numeric(new_series)
grouping_properties = [v for v in PROPERTIES if v[0] not in "xy"] pair_variables = p._pair_spec.get("structure", {}) for subplots, df, scales in self._generate_pairings(data, pair_variables): orient = layer["orient"] or mark._infer_orient(scales) def get_order(var): # Ignore order for x/y: they have been scaled to numeric indices, # so any original order is no longer valid. Default ordering rules # sorted unique numbers will correctly reconstruct intended order # TODO This is tricky, make sure we add some tests for this if var not in "xy" and var in scales: return getattr(scales[var], "order", None) if orient in df: width = pd.Series(index=df.index, dtype=float) for view in subplots: view_idx = self._get_subplot_data( df, orient, view, p._shares.get(orient) ).index view_df = df.loc[view_idx] if "width" in mark._mappable_props: view_width = mark._resolve(view_df, "width", None) elif "width" in df: view_width = view_df["width"] else: view_width = 0.8 # TODO what default? spacing = scales[orient]._spacing(view_df.loc[view_idx, orient]) width.loc[view_idx] = view_width * spacing df["width"] = width if "baseline" in mark._mappable_props: # TODO what marks should have this? # If we can set baseline with, e.g., Bar(), then the # "other" (e.g. y for x oriented bars) parameterization # is somewhat ambiguous. baseline = mark._resolve(df, "baseline", None) else: # TODO unlike width, we might not want to add baseline to data # if the mark doesn't use it. 
Practically, there is a concern about # Mark abstraction like Area / Ribbon baseline = 0 if "baseline" not in df else df["baseline"] df["baseline"] = baseline if move is not None: moves = move if isinstance(move, list) else [move] for move_step in moves: move_by = getattr(move_step, "by", None) if move_by is None: move_by = grouping_properties move_groupers = [*move_by, *default_grouping_vars] if move_step.group_by_orient: move_groupers.insert(0, orient) order = {var: get_order(var) for var in move_groupers} groupby = GroupBy(order) df = move_step(df, groupby, orient, scales) df = self._unscale_coords(subplots, df, orient) grouping_vars = mark._grouping_props + default_grouping_vars split_generator = self._setup_split_generator(grouping_vars, df, subplots) mark._plot(split_generator, scales, orient) # TODO is this the right place for this? for view in self._subplots: view["ax"].autoscale_view() if layer["legend"]: self._update_legend_contents(p, mark, data, scales, layer["label"]) def _unscale_coords( self, subplots: list[dict], df: DataFrame, orient: str, ) -> DataFrame: # TODO do we still have numbers in the variable name at this point? 
coord_cols = [c for c in df if re.match(r"^[xy]\D*$", str(c))] out_df = ( df .drop(coord_cols, axis=1) .reindex(df.columns, axis=1) # So unscaled columns retain their place .copy(deep=False) ) for view in subplots: view_df = self._filter_subplot_data(df, view) axes_df = view_df[coord_cols] for var, values in axes_df.items(): axis = getattr(view["ax"], f"{str(var)[0]}axis") # TODO see https://github.com/matplotlib/matplotlib/issues/22713 transform = axis.get_transform().inverted().transform inverted = transform(values) out_df.loc[values.index, str(var)] = inverted return out_df def _generate_pairings( self, data: PlotData, pair_variables: dict, ) -> Generator[ tuple[list[dict], DataFrame, dict[str, Scale]], None, None ]: # TODO retype return with subplot_spec or similar iter_axes = itertools.product(*[ pair_variables.get(axis, [axis]) for axis in "xy" ]) for x, y in iter_axes: subplots = [] for view in self._subplots: if (view["x"] == x) and (view["y"] == y): subplots.append(view) if data.frame.empty and data.frames: out_df = data.frames[(x, y)].copy() elif not pair_variables: out_df = data.frame.copy() else: if data.frame.empty and data.frames: out_df = data.frames[(x, y)].copy() else: out_df = data.frame.copy() scales = self._scales.copy() if x in out_df: scales["x"] = self._scales[x] if y in out_df: scales["y"] = self._scales[y] for axis, var in zip("xy", (x, y)): if axis != var: out_df = out_df.rename(columns={var: axis}) cols = [col for col in out_df if re.match(rf"{axis}\d+", str(col))] out_df = out_df.drop(cols, axis=1) yield subplots, out_df, scales def _get_subplot_index(self, df: DataFrame, subplot: dict) -> Index: dims = df.columns.intersection(["col", "row"]) if dims.empty: return df.index keep_rows = pd.Series(True, df.index, dtype=bool) for dim in dims: keep_rows &= df[dim] == subplot[dim] return df.index[keep_rows] def _filter_subplot_data(self, df: DataFrame, subplot: dict) -> DataFrame: # TODO note redundancies with preceding function ... 
needs refactoring dims = df.columns.intersection(["col", "row"]) if dims.empty: return df keep_rows = pd.Series(True, df.index, dtype=bool) for dim in dims: keep_rows &= df[dim] == subplot[dim] return df[keep_rows] def _setup_split_generator( self, grouping_vars: list[str], df: DataFrame, subplots: list[dict[str, Any]], ) -> Callable[[], Generator]: grouping_keys = [] grouping_vars = [ v for v in grouping_vars if v in df and v not in ["col", "row"] ] for var in grouping_vars: order = getattr(self._scales[var], "order", None) if order is None: order = categorical_order(df[var]) grouping_keys.append(order) def split_generator(keep_na=False) -> Generator: for view in subplots: axes_df = self._filter_subplot_data(df, view) axes_df_inf_as_nan = axes_df.copy() axes_df_inf_as_nan = axes_df_inf_as_nan.mask( axes_df_inf_as_nan.isin([np.inf, -np.inf]), np.nan ) if keep_na: # The simpler thing to do would be x.dropna().reindex(x.index). # But that doesn't work with the way that the subset iteration # is written below, which assumes data for grouping vars. # Matplotlib (usually?) masks nan data, so this should "work". # Downstream code can also drop these rows, at some speed cost. 
present = axes_df_inf_as_nan.notna().all(axis=1) nulled = {} for axis in "xy": if axis in axes_df: nulled[axis] = axes_df[axis].where(present) axes_df = axes_df_inf_as_nan.assign(**nulled) else: axes_df = axes_df_inf_as_nan.dropna() subplot_keys = {} for dim in ["col", "row"]: if view[dim] is not None: subplot_keys[dim] = view[dim] if not grouping_vars or not any(grouping_keys): if not axes_df.empty: yield subplot_keys, axes_df.copy(), view["ax"] continue grouped_df = axes_df.groupby( grouping_vars, sort=False, as_index=False, observed=False, ) for key in itertools.product(*grouping_keys): pd_key = ( key[0] if len(key) == 1 and _version_predates(pd, "2.2.0") else key ) try: df_subset = grouped_df.get_group(pd_key) except KeyError: # TODO (from initial work on categorical plots refactor) # We are adding this to allow backwards compatability # with the empty artists that old categorical plots would # add (before 0.12), which we may decide to break, in which # case this option could be removed df_subset = axes_df.loc[[]] if df_subset.empty: continue sub_vars = dict(zip(grouping_vars, key)) sub_vars.update(subplot_keys) # TODO need copy(deep=...) policy (here, above, anywhere else?) yield sub_vars, df_subset.copy(), view["ax"] return split_generator def _update_legend_contents( self, p: Plot, mark: Mark, data: PlotData, scales: dict[str, Scale], layer_label: str | None, ) -> None: """Add legend artists / labels for one layer in the plot.""" if data.frame.empty and data.frames: legend_vars: list[str] = [] for frame in data.frames.values(): frame_vars = frame.columns.intersection(list(scales)) legend_vars.extend(v for v in frame_vars if v not in legend_vars) else: legend_vars = list(data.frame.columns.intersection(list(scales))) # First handle layer legends, which occupy a single entry in legend_contents. 
if layer_label is not None: legend_title = str(p._labels.get("legend", "")) layer_key = (legend_title, -1) artist = mark._legend_artist([], None, {}) if artist is not None: for content in self._legend_contents: if content[0] == layer_key: content[1].append(artist) content[2].append(layer_label) break else: self._legend_contents.append((layer_key, [artist], [layer_label])) # Then handle the scale legends # First pass: Identify the values that will be shown for each variable schema: list[tuple[ tuple[str, str | int], list[str], tuple[list[Any], list[str]] ]] = [] schema = [] for var in legend_vars: var_legend = scales[var]._legend if var_legend is not None: values, labels = var_legend for (_, part_id), part_vars, _ in schema: if data.ids[var] == part_id: # Allow multiple plot semantics to represent same data variable part_vars.append(var) break else: title = self._resolve_label(p, var, data.names[var]) entry = (title, data.ids[var]), [var], (values, labels) schema.append(entry) # Second pass, generate an artist corresponding to each value contents: list[tuple[tuple[str, str | int], Any, list[str]]] = [] for key, variables, (values, labels) in schema: artists = [] for val in values: artist = mark._legend_artist(variables, val, scales) if artist is not None: artists.append(artist) if artists: contents.append((key, artists, labels)) self._legend_contents.extend(contents) def _make_legend(self, p: Plot) -> None: """Create the legend artist(s) and add onto the figure.""" # Combine artists representing same information across layers # Input list has an entry for each distinct variable in each layer # Output dict has an entry for each distinct variable merged_contents: dict[ tuple[str, str | int], tuple[list[tuple[Artist, ...]], list[str]], ] = {} for key, new_artists, labels in self._legend_contents: # Key is (name, id); we need the id to resolve variable uniqueness, # but will need the name in the next step to title the legend if key not in merged_contents: # Matplotlib 
accepts a tuple of artists and will overlay them new_artist_tuples = [tuple([a]) for a in new_artists] merged_contents[key] = new_artist_tuples, labels else: existing_artists = merged_contents[key][0] for i, new_artist in enumerate(new_artists): existing_artists[i] += tuple([new_artist]) # When using pyplot, an "external" legend won't be shown, so this # keeps it inside the axes (though still attached to the figure) # This is necessary because matplotlib layout engines currently don't # support figure legends — ideally this will change. loc = "center right" if self._pyplot else "center left" base_legend = None for (name, _), (handles, labels) in merged_contents.items(): legend = mpl.legend.Legend( self._figure, handles, # type: ignore # matplotlib/issues/26639 labels, title=name, loc=loc, bbox_to_anchor=(.98, .55), ) if base_legend: # Matplotlib has no public API for this so it is a bit of a hack. # Ideally we'd define our own legend class with more flexibility, # but that is a lot of work! base_legend_box = base_legend.get_children()[0] this_legend_box = legend.get_children()[0] base_legend_box.get_children().extend(this_legend_box.get_children()) else: base_legend = legend self._figure.legends.append(legend) def _finalize_figure(self, p: Plot) -> None: for sub in self._subplots: ax = sub["ax"] for axis in "xy": axis_key = sub[axis] axis_obj = getattr(ax, f"{axis}axis") # Axis limits if axis_key in p._limits or axis in p._limits: convert_units = getattr(ax, f"{axis}axis").convert_units a, b = p._limits.get(axis_key) or p._limits[axis] lo = a if a is None else convert_units(a) hi = b if b is None else convert_units(b) if isinstance(a, str): lo = cast(float, lo) - 0.5 if isinstance(b, str): hi = cast(float, hi) + 0.5 ax.set(**{f"{axis}lim": (lo, hi)}) if axis_key in self._scales: # TODO when would it not be? 
self._scales[axis_key]._finalize(p, axis_obj) if (engine_name := p._layout_spec.get("engine", default)) is not default: # None is a valid arg for Figure.set_layout_engine, hence `default` set_layout_engine(self._figure, engine_name) elif p._target is None: # Don't modify the layout engine if the user supplied their own # matplotlib figure and didn't specify an engine through Plot # TODO switch default to "constrained"? # TODO either way, make configurable set_layout_engine(self._figure, "tight") if (extent := p._layout_spec.get("extent")) is not None: engine = get_layout_engine(self._figure) if engine is None: self._figure.subplots_adjust(*extent) else: # Note the different parameterization for the layout engine rect... left, bottom, right, top = extent width, height = right - left, top - bottom try: # The base LayoutEngine.set method doesn't have rect= so we need # to avoid typechecking this statement. We also catch a TypeError # as a plugin LayoutEngine may not support it either. # Alternatively we could guard this with a check on the engine type, # but that would make later-developed engines would un-useable. engine.set(rect=[left, bottom, width, height]) # type: ignore except TypeError: # Should we warn / raise? Note that we don't expect to get here # under any normal circumstances. 
class Property:
    """Base class for visual properties that can be set directly or be data scaling."""

    # When True, scales for this property will populate the legend by default
    legend = False

    # When True, scales for this property normalize data to [0, 1] before mapping
    normed = False

    def __init__(self, variable: str | None = None):
        """Store the plot variable name, defaulting to the lowercased class name."""
        self.variable = variable or self.__class__.__name__.lower()

    def default_scale(self, data: Series) -> Scale:
        """Given data, initialize appropriate scale class."""
        kind = variable_type(data, boolean_type="boolean", strict_boolean=True)
        if kind == "numeric":
            return Continuous()
        if kind == "datetime":
            return Temporal()
        if kind == "boolean":
            return Boolean()
        return Nominal()

    def infer_scale(self, arg: Any, data: Series) -> Scale:
        """
        Given data and a scaling argument, initialize appropriate scale class.

        Only string arguments naming a transform are accepted here; other
        strings raise ``ValueError`` and non-strings raise ``TypeError``.
        """
        # TODO put these somewhere external for validation
        # TODO putting this here won't pick it up if subclasses define infer_scale
        # (e.g. color). How best to handle that? One option is to call super after
        # handling property-specific possibilities (e.g. for color check that the
        # arg is not a valid palette name) but that could get tricky.
        if not isinstance(arg, str):
            arg_type = type(arg).__name__
            msg = f"Magic arg for {self.variable} scale must be str, not {arg_type}."
            raise TypeError(msg)

        transform_prefixes = ("log", "symlog", "logit", "pow", "sqrt")
        if arg.startswith(transform_prefixes):
            # TODO validate numeric type? That should happen centrally somewhere
            return Continuous(trans=arg)

        msg = f"Unknown magic arg for {self.variable} scale: '{arg}'."
        raise ValueError(msg)

    def get_mapping(self, scale: Scale, data: Series) -> Mapping:
        """Return a function that maps from data domain to property range."""
        def passthrough(x):
            return x
        return passthrough

    def standardize(self, val: Any) -> Any:
        """Coerce flexible property value to standardized representation."""
        return val

    def _check_dict_entries(self, levels: list, values: dict) -> None:
        """Raise ``ValueError`` naming every level that has no key in ``values``."""
        absent = set(levels).difference(values)
        if not absent:
            return
        formatted = ", ".join(repr(level) for level in sorted(absent, key=str))
        err = f"No entry in {self.variable} dictionary for {formatted}"
        raise ValueError(err)

    def _check_list_length(self, levels: list, values: list) -> list:
        """
        Return ``values`` fitted to the number of levels, warning on a mismatch.

        A list that is too short is cycled; a list that is too long is truncated.
        """
        needed, given = len(levels), len(values)
        message = ""
        if given < needed:
            message = " ".join([
                f"\nThe {self.variable} list has fewer values ({given})",
                f"than needed ({needed}) and will cycle, which may",
                "produce an uninterpretable plot."
            ])
            values = list(itertools.islice(itertools.cycle(values), needed))
        elif given > needed:
            message = " ".join([
                f"The {self.variable} list has more values ({given})",
                f"than needed ({needed}), which may not be intended.",
            ])
            values = values[:needed]

        # TODO look into custom PlotSpecWarning with better formatting
        if message:
            warnings.warn(message, UserWarning)

        return values
class IntervalProperty(Property):
    """A numeric property where scale range can be defined as an interval."""
    legend = True
    normed = True

    _default_range: tuple[float, float] = (0, 1)

    @property
    def default_range(self) -> tuple[float, float]:
        """Min and max values used by default for semantic mapping."""
        return self._default_range

    def _forward(self, values: ArrayLike) -> ArrayLike:
        """Transform applied to native values before linear mapping into interval."""
        return values

    def _inverse(self, values: ArrayLike) -> ArrayLike:
        """Transform applied to results of mapping that returns to native values."""
        return values

    def infer_scale(self, arg: Any, data: Series) -> Scale:
        """Given data and a scaling argument, initialize appropriate scale class."""
        # TODO infer continuous based on log/sqrt etc?
        kind = variable_type(data, boolean_type="boolean", strict_boolean=True)

        if kind == "boolean":
            return Boolean(arg)
        if isinstance(arg, (list, dict)) or kind == "categorical":
            return Nominal(arg)
        if kind == "datetime":
            return Temporal(arg)
        # TODO other variable types
        return Continuous(arg)

    def get_mapping(self, scale: Scale, data: Series) -> Mapping:
        """
        Return a function that maps from data domain to property range.

        Nominal and Boolean scales are delegated to their own helpers. For any
        other scale, ``scale.values`` must be ``None`` or a 2-tuple, otherwise
        ``TypeError`` is raised.
        """
        if isinstance(scale, Nominal):
            return self._get_nominal_mapping(scale, data)
        if isinstance(scale, Boolean):
            return self._get_boolean_mapping(scale, data)

        spec = scale.values
        if spec is None:
            interval = self.default_range
        elif isinstance(spec, tuple) and len(spec) == 2:
            interval = spec
        else:
            if isinstance(spec, tuple):
                actual = f"{len(spec)}-tuple"
            else:
                actual = str(type(spec))
            scale_class = scale.__class__.__name__
            err = " ".join([
                f"Values for {self.variable} variables with {scale_class} scale",
                f"must be 2-tuple; not {actual}.",
            ])
            raise TypeError(err)

        vmin, vmax = self._forward(interval)

        def mapping(x):
            # Linear map into [vmin, vmax] in forward space, then back to native
            return self._inverse(np.multiply(x, vmax - vmin) + vmin)

        return mapping

    def _get_nominal_mapping(self, scale: Nominal, data: Series) -> Mapping:
        """Identify evenly-spaced values using interval or explicit mapping."""
        levels = categorical_order(data, scale.order)
        lookup = self._get_values(scale, levels)

        def mapping(x):
            codes = np.asarray(x, np.intp)
            result = np.full(len(x), np.nan)
            finite = np.isfinite(x)
            # Non-finite inputs keep the NaN fill value
            result[finite] = np.take(lookup, codes[finite])
            return result

        return mapping

    def _get_boolean_mapping(self, scale: Boolean, data: Series) -> Mapping:
        """Identify evenly-spaced values using interval or explicit mapping."""
        pair = self._get_values(scale, [True, False])

        def mapping(x):
            result = np.full(len(x), np.nan)
            finite = np.isfinite(x)
            result[finite] = np.where(x[finite], *pair)
            return result

        return mapping

    def _get_values(self, scale: Scale, levels: list) -> list:
        """Validate scale.values and identify a value for each level."""
        spec = scale.values

        if isinstance(spec, dict):
            self._check_dict_entries(levels, spec)
            return [spec[level] for level in levels]

        if isinstance(spec, list):
            return self._check_list_length(levels, spec)

        if spec is None:
            vmin, vmax = self.default_range
        elif isinstance(spec, tuple):
            vmin, vmax = spec
        else:
            scale_class = scale.__class__.__name__
            err = " ".join([
                f"Values for {self.variable} variables with {scale_class} scale",
                f"must be a dict, list or tuple; not {type(spec)}",
            ])
            raise TypeError(err)

        # Space the levels evenly in forward space, from vmax down to vmin
        vmin, vmax = self._forward([vmin, vmax])
        return list(self._inverse(np.linspace(vmax, vmin, len(levels))))
class ObjectProperty(Property):
    """A property defined by an arbitrary object, with inherently nominal scaling."""
    legend = True
    normed = False

    # Object representing null data, should appear invisible when drawn by matplotlib
    # Note that we now drop nulls in Plot._plot_layer and thus may not need this
    null_value: Any = None

    def _default_values(self, n: int) -> list:
        """Return ``n`` default objects; concrete subclasses must implement this."""
        raise NotImplementedError()

    def default_scale(self, data: Series) -> Scale:
        """Pick a Boolean scale for strictly boolean data, else a Nominal scale."""
        kind = variable_type(data, boolean_type="boolean", strict_boolean=True)
        if kind == "boolean":
            return Boolean()
        return Nominal()

    def infer_scale(self, arg: Any, data: Series) -> Scale:
        """Wrap ``arg`` in a Boolean scale for boolean data, else a Nominal scale."""
        kind = variable_type(data, boolean_type="boolean", strict_boolean=True)
        if kind == "boolean":
            return Boolean(arg)
        return Nominal(arg)

    def get_mapping(self, scale: Scale, data: Series) -> Mapping:
        """Define mapping as lookup into list of object values."""
        is_boolean = isinstance(scale, Boolean)
        fallback_order = [True, False] if is_boolean else None
        order = getattr(scale, "order", fallback_order)
        lookup = self._get_values(scale, categorical_order(data, order))

        if is_boolean:
            # Values were assigned in [True, False] order; reverse the list so
            # that index 1 picks the True value and index 0 the False value.
            lookup = lookup[::-1]

        def mapping(x):
            indices = np.asarray(np.nan_to_num(x), np.intp)
            result = []
            for raw, index in zip(x, indices):
                result.append(lookup[index] if np.isfinite(raw) else self.null_value)
            return result

        return mapping

    def _get_values(self, scale: Scale, levels: list) -> list:
        """Validate scale.values and identify a value for each level."""
        spec = scale.values
        if isinstance(spec, dict):
            self._check_dict_entries(levels, spec)
            raw_values = [spec[level] for level in levels]
        elif isinstance(spec, list):
            raw_values = self._check_list_length(levels, spec)
        elif spec is None:
            raw_values = self._default_values(len(levels))
        else:
            raise TypeError(" ".join([
                f"Scale values for a {self.variable} variable must be provided",
                f"in a dict or list; not {type(scale.values)}."
            ]))

        return [self.standardize(value) for value in raw_values]
see d3 options) # TODO need some sort of "require_scale" functionality # to raise when we get the wrong kind explicitly specified def standardize(self, val: MarkerPattern) -> MarkerStyle: return MarkerStyle(val) def _default_values(self, n: int) -> list[MarkerStyle]: """Build an arbitrarily long list of unique marker styles. Parameters ---------- n : int Number of unique marker specs to generate. Returns ------- markers : list of string or tuples Values for defining :class:`matplotlib.markers.MarkerStyle` objects. All markers will be filled. """ # Start with marker specs that are well distinguishable markers = [ "o", "X", (4, 0, 45), "P", (4, 0, 0), (4, 1, 0), "^", (4, 1, 45), "v", ] # Now generate more from regular polygons of increasing order s = 5 while len(markers) < n: a = 360 / (s + 1) / 2 markers.extend([(s + 1, 1, a), (s + 1, 0, a), (s, 1, 0), (s, 0, 0)]) s += 1 markers = [MarkerStyle(m) for m in markers[:n]] return markers class LineStyle(ObjectProperty): """Dash pattern for line-type marks.""" null_value = "" def standardize(self, val: str | DashPattern) -> DashPatternWithOffset: return self._get_dash_pattern(val) def _default_values(self, n: int) -> list[DashPatternWithOffset]: """Build an arbitrarily long list of unique dash styles for lines. Parameters ---------- n : int Number of unique dash specs to generate. Returns ------- dashes : list of strings or tuples Valid arguments for the ``dashes`` parameter on :class:`matplotlib.lines.Line2D`. The first spec is a solid line (``""``), the remainder are sequences of long and short dashes. 
""" # Start with dash specs that are well distinguishable dashes: list[str | DashPattern] = [ "-", (4, 1.5), (1, 1), (3, 1.25, 1.5, 1.25), (5, 1, 1, 1), ] # Now programmatically build as many as we need p = 3 while len(dashes) < n: # Take combinations of long and short dashes a = itertools.combinations_with_replacement([3, 1.25], p) b = itertools.combinations_with_replacement([4, 1], p) # Interleave the combinations, reversing one of the streams segment_list = itertools.chain(*zip(list(a)[1:-1][::-1], list(b)[1:-1])) # Now insert the gaps for segments in segment_list: gap = min(segments) spec = tuple(itertools.chain(*((seg, gap) for seg in segments))) dashes.append(spec) p += 1 return [self._get_dash_pattern(x) for x in dashes] @staticmethod def _get_dash_pattern(style: str | DashPattern) -> DashPatternWithOffset: """Convert linestyle arguments to dash pattern with offset.""" # Copied and modified from Matplotlib 3.4 # go from short hand -> full strings ls_mapper = {"-": "solid", "--": "dashed", "-.": "dashdot", ":": "dotted"} if isinstance(style, str): style = ls_mapper.get(style, style) # un-dashed styles if style in ["solid", "none", "None"]: offset = 0 dashes = None # dashed styles elif style in ["dashed", "dashdot", "dotted"]: offset = 0 dashes = tuple(mpl.rcParams[f"lines.{style}_pattern"]) else: options = [*ls_mapper.values(), *ls_mapper.keys()] msg = f"Linestyle string must be one of {options}, not {repr(style)}." raise ValueError(msg) elif isinstance(style, tuple): if len(style) > 1 and isinstance(style[1], tuple): offset, dashes = style elif len(style) > 1 and style[1] is None: offset, dashes = style else: offset = 0 dashes = style else: val_type = type(style).__name__ msg = f"Linestyle must be str or tuple, not {val_type}." 
raise TypeError(msg) # Normalize offset to be positive and shorter than the dash cycle if dashes is not None: try: dsum = sum(dashes) except TypeError as err: msg = f"Invalid dash pattern: {dashes}" raise TypeError(msg) from err if dsum: offset %= dsum return offset, dashes class TextAlignment(ObjectProperty): legend = False class HorizontalAlignment(TextAlignment): def _default_values(self, n: int) -> list: vals = itertools.cycle(["left", "right"]) return [next(vals) for _ in range(n)] class VerticalAlignment(TextAlignment): def _default_values(self, n: int) -> list: vals = itertools.cycle(["top", "bottom"]) return [next(vals) for _ in range(n)] # =================================================================================== # # Properties with RGB(A) color values # =================================================================================== # class Color(Property): """Color, as RGB(A), scalable with nominal palettes or continuous gradients.""" legend = True normed = True def standardize(self, val: ColorSpec) -> RGBTuple | RGBATuple: # Return color with alpha channel only if the input spec has it # This is so that RGBA colors can override the Alpha property if to_rgba(val) != to_rgba(val, 1): return to_rgba(val) else: return to_rgb(val) def _standardize_color_sequence(self, colors: ArrayLike) -> ArrayLike: """Convert color sequence to RGB(A) array, preserving but not adding alpha.""" def has_alpha(x): return to_rgba(x) != to_rgba(x, 1) if isinstance(colors, np.ndarray): needs_alpha = colors.shape[1] == 4 else: needs_alpha = any(has_alpha(x) for x in colors) if needs_alpha: return to_rgba_array(colors) else: return to_rgba_array(colors)[:, :3] def infer_scale(self, arg: Any, data: Series) -> Scale: # TODO when inferring Continuous without data, verify type # TODO need to rethink the variable type system # (e.g. boolean, ordered categories as Ordinal, etc).. 
def get_mapping(self, scale: Scale, data: Series) -> Mapping:
    """Return a function that maps from data domain to color values.

    Nominal and Boolean scales are delegated to their dedicated helpers. For
    any other scale, ``scale.values`` selects the gradient: None uses the
    default "ch:" palette, a tuple is blended, a string names a palette, and
    a callable is used directly.

    Raises
    ------
    TypeError
        If ``scale.values`` is none of the supported kinds.
    """
    # TODO what is best way to do this conditional?
    # Should it be class-based or should classes have behavioral attributes?
    if isinstance(scale, Nominal):
        return self._get_nominal_mapping(scale, data)
    if isinstance(scale, Boolean):
        return self._get_boolean_mapping(scale, data)

    spec = scale.values
    if spec is None:
        # TODO Rethink best default continuous color gradient
        cmap = color_palette("ch:", as_cmap=True)
    elif isinstance(spec, tuple):
        # TODO blend_palette will strip alpha, but we should support
        # interpolation on all four channels
        cmap = blend_palette(spec, as_cmap=True)
    elif isinstance(spec, str):
        # TODO for matplotlib colormaps this will clip extremes, which is
        # different from what using the named colormap directly would do
        # This may or may not be desirable.
        cmap = color_palette(spec, as_cmap=True)
    elif callable(spec):
        cmap = spec
    else:
        scale_class = scale.__class__.__name__
        raise TypeError(" ".join([
            f"Scale values for {self.variable} with a {scale_class} mapping",
            f"must be string, tuple, or callable; not {type(scale.values)}."
        ]))

    def _mapping(x):
        # Remove alpha channel so it does not override alpha property downstream
        # TODO this will need to be more flexible to support RGBA tuples (see above)
        missing = ~np.isfinite(x)
        rgb = cmap(x)[:, :3]
        rgb[missing] = np.nan
        return rgb

    return _mapping
f"Scale values for {self.variable} with a {scale_class} mapping", f"must be string, list, tuple, or dict; not {type(scale.values)}." ]) raise TypeError(msg) return self._standardize_color_sequence(colors) # =================================================================================== # # Properties that can take only two states # =================================================================================== # class Fill(Property): """Boolean property of points/bars/patches that can be solid or outlined.""" legend = True normed = False def default_scale(self, data: Series) -> Scale: var_type = variable_type(data, boolean_type="boolean", strict_boolean=True) return Boolean() if var_type == "boolean" else Nominal() def infer_scale(self, arg: Any, data: Series) -> Scale: var_type = variable_type(data, boolean_type="boolean", strict_boolean=True) return Boolean(arg) if var_type == "boolean" else Nominal(arg) def standardize(self, val: Any) -> bool: return bool(val) def _default_values(self, n: int) -> list: """Return a list of n values, alternating True and False.""" if n > 2: msg = " ".join([ f"The variable assigned to {self.variable} has more than two levels,", f"so {self.variable} values will cycle and may be uninterpretable", ]) # TODO fire in a "nice" way (see above) warnings.warn(msg, UserWarning) return [x for x, _ in zip(itertools.cycle([True, False]), range(n))] def get_mapping(self, scale: Scale, data: Series) -> Mapping: """Return a function that maps each data value to True or False.""" boolean_scale = isinstance(scale, Boolean) order = getattr(scale, "order", [True, False] if boolean_scale else None) levels = categorical_order(data, order) values = self._get_values(scale, levels) if boolean_scale: values = values[::-1] def mapping(x): ixs = np.asarray(np.nan_to_num(x), np.intp) return [ values[ix] if np.isfinite(x_i) else False for x_i, ix in zip(x, ixs) ] return mapping def _get_values(self, scale: Scale, levels: list) -> list: """Validate 
class VarType(UserString):
    """
    Prevent comparisons elsewhere in the library from using the wrong name.

    Errors are simple assertions because users should not be able to trigger
    them. If that changes, they should be more verbose.

    """
    # TODO VarType is an awfully overloaded name, but so is DataType ...
    # TODO adding unknown because we are using this in for scales, is that right?
    allowed = "numeric", "datetime", "categorical", "boolean", "unknown"

    def __init__(self, data):
        # Only names listed in `allowed` may be wrapped
        assert data in self.allowed, data
        super().__init__(data)

    def __eq__(self, other):
        # Comparing against a name outside `allowed` is treated as a
        # programming error rather than simply evaluating to False
        assert other in self.allowed, other
        return self.data == other

    def __hash__(self):
        # Defining __eq__ sets __hash__ to None unless it is defined as well.
        # Hash the wrapped string so that a VarType hashes like the plain
        # string it compares equal to.
        return hash(self.data)
""" # If a categorical dtype is set, infer categorical if isinstance(getattr(vector, 'dtype', None), pd.CategoricalDtype): return VarType("categorical") # Special-case all-na data, which is always "numeric" if pd.isna(vector).all(): return VarType("numeric") # Now drop nulls to simplify further type inference vector = vector.dropna() # Special-case binary/boolean data, allow caller to determine # This triggers a numpy warning when vector has strings/objects # https://github.com/numpy/numpy/issues/6784 # Because we reduce with .all(), we are agnostic about whether the # comparison returns a scalar or vector, so we will ignore the warning. # It triggers a separate DeprecationWarning when the vector has datetimes: # https://github.com/numpy/numpy/issues/13548 # This is considered a bug by numpy and will likely go away. with warnings.catch_warnings(): warnings.simplefilter( action='ignore', category=(FutureWarning, DeprecationWarning) # type: ignore # mypy bug? ) if strict_boolean: if isinstance(vector.dtype, pd.core.dtypes.base.ExtensionDtype): boolean_dtypes = ["bool", "boolean"] else: boolean_dtypes = ["bool"] boolean_vector = vector.dtype in boolean_dtypes else: try: boolean_vector = bool(np.isin(vector, [0, 1]).all()) except TypeError: # .isin comparison is not guaranteed to be possible under NumPy # casting rules, depending on the (unknown) dtype of 'vector' boolean_vector = False if boolean_vector: return VarType(boolean_type) # Defer to positive pandas tests if pd.api.types.is_numeric_dtype(vector): return VarType("numeric") if pd.api.types.is_datetime64_dtype(vector): return VarType("datetime") # --- If we get to here, we need to check the entries # Check for a collection where everything is a number def all_numeric(x): for x_i in x: if not isinstance(x_i, Number): return False return True if all_numeric(vector): return VarType("numeric") # Check for a collection where everything is a datetime def all_datetime(x): for x_i in x: if not isinstance(x_i, 
def categorical_order(vector: Series, order: list | None = None) -> list:
    """
    Return a list of unique data values using seaborn's ordering rules.

    Parameters
    ----------
    vector : Series
        Vector of "categorical" values
    order : list
        Desired order of category levels. When given, it is returned as-is
        and overrides the order determined from `vector`.

    Returns
    -------
    order : list
        Ordered list of category levels not including null values.

    """
    if order is not None:
        return order

    # A categorical dtype already carries its own level order
    if vector.dtype.name == "category":
        return list(vector.cat.categories)

    # Otherwise keep order of appearance, sorting only when levels are numeric
    levels = [level for level in vector.unique() if pd.notnull(level)]
    if variable_type(pd.Series(levels)) == "numeric":
        levels.sort()
    return levels
class Scale:
    """Base class for objects that map data values to visual properties."""

    values: tuple | str | list | dict | None

    _priority: ClassVar[int]
    _pipeline: Pipeline
    _matplotlib_scale: ScaleBase
    _spacer: staticmethod
    _legend: tuple[list[Any], list[str]] | None

    def __post_init__(self):
        """Start with no tick, label, or legend configuration."""
        self._tick_params = self._label_params = self._legend = None

    def tick(self):
        """Configure tick selection; must be provided by subclasses."""
        raise NotImplementedError()

    def label(self):
        """Configure tick labels; must be provided by subclasses."""
        raise NotImplementedError()

    def _get_locators(self):
        """Return the tick locators; must be provided by subclasses."""
        raise NotImplementedError()

    def _get_formatter(self, locator: Locator | None = None):
        """Return the tick formatter; must be provided by subclasses."""
        raise NotImplementedError()

    def _get_scale(self, name: str, forward: Callable, inverse: Callable):
        """Build a matplotlib FuncScale that installs this scale's ticks and labels."""
        locators = self._get_locators(**self._tick_params)
        major_locator, minor_locator = locators
        major_formatter = self._get_formatter(major_locator, **self._label_params)

        class InternalScale(mpl.scale.FuncScale):
            def set_default_locators_and_formatters(self, axis):
                axis.set_major_locator(major_locator)
                # A minor locator is optional
                if minor_locator is not None:
                    axis.set_minor_locator(minor_locator)
                axis.set_major_formatter(major_formatter)

        return InternalScale(name, (forward, inverse))

    def _spacing(self, x: Series) -> float:
        """Return the spacer's result for ``x``, or 1 when that result is NaN."""
        space = self._spacer(x)
        # NaN happens when there is no variance in the orient coordinate data
        # Not exactly clear what the right default is, but 1 seems reasonable?
        return 1 if np.isnan(space) else space

    def _setup(
        self, data: Series, prop: Property, axis: Axis | None = None,
    ) -> Scale:
        """Return a configured scale; must be provided by subclasses."""
        raise NotImplementedError()

    def _finalize(self, p: Plot, axis: Axis) -> None:
        """Perform scale-specific axis tweaks after adding artists."""

    def __call__(self, data: Series) -> ArrayLike:
        """Run ``data`` through each non-None step of the pipeline, in order."""
        # TODO sometimes we need to handle scalars (e.g. for Line)
        # but what is the best way to do that?
        is_scalar = np.isscalar(data)
        result = np.array([data]) if is_scalar else data

        for step in self._pipeline:
            if step is None:
                continue
            result = step(result)

        # Unwrap the one-element array that a scalar input was packed into
        return result[0] if is_scalar else result

    @staticmethod
    def _identity():
        """Return a scale with an empty pipeline, which passes data through."""

        class Identity(Scale):
            _pipeline = []
            _spacer = None
            _legend = None
            _matplotlib_scale = None

        return Identity()
also want # True/False to be drawn at 1/0 positions respectively to avoid nasty # surprises if additional artists are added through the matplotlib API. # We accomplish this using axis inversion akin to what we do in Nominal. ax = axis.axes name = axis.axis_name axis.grid(False, which="both") if name not in p._limits: nticks = len(axis.get_major_ticks()) lo, hi = -.5, nticks - .5 if name == "x": lo, hi = hi, lo set_lim = getattr(ax, f"set_{name}lim") set_lim(lo, hi, auto=None) def tick(self, locator: Locator | None = None): new = copy(self) new._tick_params = {"locator": locator} return new def label(self, formatter: Formatter | None = None): new = copy(self) new._label_params = {"formatter": formatter} return new def _get_locators(self, locator): if locator is not None: return locator return FixedLocator([0, 1]), None def _get_formatter(self, locator, formatter): if formatter is not None: return formatter return FuncFormatter(lambda x, _: str(bool(x))) @dataclass class Nominal(Scale): """ A categorical scale without relative importance / magnitude. """ # Categorical (convert to strings), un-sortable values: tuple | str | list | dict | None = None order: list | None = None _priority: ClassVar[int] = 4 def _setup( self, data: Series, prop: Property, axis: Axis | None = None, ) -> Scale: new = copy(self) if new._tick_params is None: new = new.tick() if new._label_params is None: new = new.label() # TODO flexibility over format() which isn't great for numbers / dates stringify = np.vectorize(format, otypes=["object"]) units_seed = categorical_order(data, new.order) # TODO move to Nominal._get_scale? # TODO this needs some more complicated rethinking about how to pass # a unit dictionary down to these methods, along with how much we want # to invest in their API. What is it useful for tick() to do here? # (Ordinal may be different if we draw that contrast). 
# Any customization we do to allow, e.g., label wrapping will probably # require defining our own Formatter subclass. # We could also potentially implement auto-wrapping in an Axis subclass # (see Axis.draw ... it already is computing the bboxes). # major_locator, minor_locator = new._get_locators(**new._tick_params) # major_formatter = new._get_formatter(major_locator, **new._label_params) class CatScale(mpl.scale.LinearScale): def set_default_locators_and_formatters(self, axis): ... # axis.set_major_locator(major_locator) # if minor_locator is not None: # axis.set_minor_locator(minor_locator) # axis.set_major_formatter(major_formatter) mpl_scale = CatScale(data.name) if axis is None: axis = PseudoAxis(mpl_scale) # TODO Currently just used in non-Coordinate contexts, but should # we use this to (A) set the padding we want for categorial plots # and (B) allow the values parameter for a Coordinate to set xlim/ylim axis.set_view_interval(0, len(units_seed) - 1) new._matplotlib_scale = mpl_scale # TODO array cast necessary to handle float/int mixture, which we need # to solve in a more systematic way probably # (i.e. if we have [1, 2.5], do we want [1.0, 2.5]? Unclear) axis.update_units(stringify(np.array(units_seed))) # TODO define this more centrally def convert_units(x): # TODO only do this with explicit order? # (But also category dtype?) # TODO isin fails when units_seed mixes numbers and strings (numpy error?) # but np.isin also does not seem any faster? 
def label(self, formatter: Formatter | None = None) -> Nominal:
    """
    Configure the selection of labels for the scale's axis or legend.

    .. note::
        This API is under construction and will be enhanced over time.
        At the moment, it is probably not very useful.

    Parameters
    ----------
    formatter : :class:`matplotlib.ticker.Formatter` subclass
        Pre-configured matplotlib formatter; other parameters will not be used.

    Returns
    -------
    scale
        Copy of self with new label configuration.

    """
    # Work on a shallow copy so the scale this is called on stays unchanged
    configured = copy(self)
    configured._label_params = dict(formatter=formatter)
    return configured
def _get_transform(self):
    """
    Resolve ``self.trans`` into a ``(forward, inverse)`` pair of callables.

    ``self.trans`` may be ``None`` (identity), an explicit tuple (returned
    unchanged), or a string naming a transform, optionally followed by a
    numeric parameter (e.g. ``"log2"``, ``"symlog100"``, ``"pow3"``).

    Raises
    ------
    ValueError
        If ``self.trans`` is an unknown string, carries a parameter that is
        not a number, or is of an unsupported type.

    """
    arg = self.trans

    def unknown():
        return ValueError(f"Unknown value provided for trans: {arg!r}")

    def get_param(method, default):
        # A bare method name selects the default parameter; otherwise the
        # text after the method name is parsed as the parameter.
        if arg == method:
            return default
        try:
            return float(arg[len(method):])
        except ValueError:
            # Report the whole spec rather than just the unparsable suffix.
            raise unknown() from None

    if arg is None:
        return _make_identity_transforms()
    if isinstance(arg, tuple):
        return arg
    if isinstance(arg, str):
        # Order matters: "logit" must be tested before the "log" prefix.
        if arg == "ln":
            return _make_log_transforms()
        elif arg == "logit":
            base = get_param("logit", 10)
            return _make_logit_transforms(base)
        elif arg.startswith("log"):
            base = get_param("log", 10)
            return _make_log_transforms(base)
        elif arg.startswith("symlog"):
            c = get_param("symlog", 1)
            return _make_symlog_transforms(c)
        elif arg.startswith("pow"):
            exp = get_param("pow", 2)
            return _make_power_transforms(exp)
        elif arg == "sqrt":
            return _make_sqrt_transforms()

    # Unknown strings and unsupported types both end up here; previously an
    # unsupported type fell off the end and silently returned None.
    raise unknown()
def tick(
    self,
    locator: Locator | None = None, *,
    at: Sequence[float] | None = None,
    upto: int | None = None,
    count: int | None = None,
    every: float | None = None,
    between: tuple[float, float] | None = None,
    minor: int | None = None,
) -> Continuous:
    """
    Configure the selection of ticks for the scale's axis or legend.

    Parameters
    ----------
    locator : :class:`matplotlib.ticker.Locator` subclass
        Pre-configured matplotlib locator; other parameters will not be used.
    at : sequence of floats
        Place ticks at these specific locations (in data units).
    upto : int
        Choose "nice" locations for ticks, but do not exceed this number.
    count : int
        Choose exactly this number of ticks, bounded by `between` or axis limits.
    every : float
        Choose locations at this interval of separation (in data units).
    between : pair of floats
        Bound upper / lower ticks when using `every` or `count`.
    minor : int
        Number of unlabeled ticks to draw between labeled "major" ticks.

    Returns
    -------
    scale
        Copy of self with new tick configuration.

    Raises
    ------
    TypeError
        If `locator` is given but is not a matplotlib ``Locator`` instance.
    RuntimeError
        If `count` (without `between`) or `every` is combined with a
        log or symlog transform.

    """
    # Validate the locator first so a bad object is reported before any
    # transform-specific complaint.
    if not (locator is None or isinstance(locator, Locator)):
        raise TypeError(
            f"Tick locator must be an instance of {Locator!r}, "
            f"not {type(locator)!r}."
        )

    log_base, symlog_thresh = self._parse_for_log_params(self.trans)
    uses_log_spacing = log_base or symlog_thresh
    if uses_log_spacing and count is not None and between is None:
        raise RuntimeError("`count` requires `between` with log transform.")
    if uses_log_spacing and every is not None:
        raise RuntimeError("`every` not supported with log transform.")

    configured = copy(self)
    configured._tick_params = dict(
        locator=locator,
        at=at,
        upto=upto,
        count=count,
        every=every,
        between=between,
        minor=minor,
    )
    return configured
Parameters ---------- formatter : :class:`matplotlib.ticker.Formatter` subclass Pre-configured formatter to use; other parameters will be ignored. like : str or callable Either a format pattern (e.g., `".2f"`), a format string with fields named `x` and/or `pos` (e.g., `"${x:.2f}"`), or a callable with a signature like `f(x: float, pos: int) -> str`. In the latter variants, `x` is passed as the tick value and `pos` is passed as the tick index. base : number Use log formatter (with scientific notation) having this value as the base. Set to `None` to override the default formatter with a log transform. unit : str or (str, str) tuple Use SI prefixes with these units (e.g., with `unit="g"`, a tick value of 5000 will appear as `5 kg`). When a tuple, the first element gives the separator between the number and unit. Returns ------- scale Copy of self with new label configuration. """ # Input checks if formatter is not None and not isinstance(formatter, Formatter): raise TypeError( f"Label formatter must be an instance of {Formatter!r}, " f"not {type(formatter)!r}" ) if like is not None and not (isinstance(like, str) or callable(like)): msg = f"`like` must be a string or callable, not {type(like).__name__}." 
def _parse_for_log_params(
    self, trans: str | TransFuncs | None
) -> tuple[float | None, float | None]:
    """
    Extract log-scale parameters from a transform specification.

    Parameters
    ----------
    trans : str, pair of callables, or None
        Transform spec. Only strings of the form ``"log<base>"`` or
        ``"symlog<threshold>"`` (numeric part optional) yield parameters.

    Returns
    -------
    log_base, symlog_thresh : float or None
        Base of a log transform (default 10) and linear threshold of a
        symlog transform (default 1); ``None`` when not applicable.

    """
    log_base = symlog_thresh = None
    if isinstance(trans, str):
        # Match the whole string so that other names sharing the prefix
        # (notably "logit") are not mistaken for a log transform, and accept
        # a decimal parameter so that "log2.5" is not truncated to base 2.
        number = r"(\d+(?:\.\d*)?|\.\d+)?"
        m = re.fullmatch("log" + number, trans)
        if m is not None:
            log_base = float(m[1] or 10)
        m = re.fullmatch("symlog" + number, trans)
        if m is not None:
            symlog_thresh = float(m[1] or 1)
    return log_base, symlog_thresh
@dataclass
class Temporal(ContinuousBase):
    """
    A scale for date/time data.
    """
    # TODO date: bool?
    # For when we only care about the time component, would affect
    # default formatter and norm conversion. Should also happen in
    # Property.default_scale. The alternative was having distinct
    # Calendric / Temporal scales, but that feels a bit fussy, and it
    # would get in the way of using first-letter shorthands because
    # Calendric and Continuous would collide. Still, we haven't implemented
    # those yet, and having a clear distinction between date(time) / time
    # may be more useful.

    trans = None

    _priority: ClassVar[int] = 2

    def tick(
        self, locator: Locator | None = None, *,
        upto: int | None = None,
    ) -> Temporal:
        """
        Configure the selection of ticks for the scale's axis or legend.

        .. note::
            This API is under construction and will be enhanced over time.

        Parameters
        ----------
        locator : :class:`matplotlib.ticker.Locator` subclass
            Pre-configured matplotlib locator; other parameters will not be used.
        upto : int
            Choose "nice" locations for ticks, but do not exceed this number.

        Returns
        -------
        scale
            Copy of self with new tick configuration.

        Raises
        ------
        TypeError
            If `locator` is given but is not a ``Locator`` instance.

        """
        if not (locator is None or isinstance(locator, Locator)):
            err = (
                f"Tick locator must be an instance of {Locator!r}, "
                f"not {type(locator)!r}."
            )
            raise TypeError(err)

        configured = copy(self)
        configured._tick_params = dict(locator=locator, upto=upto)
        return configured

    def label(
        self,
        formatter: Formatter | None = None, *,
        concise: bool = False,
    ) -> Temporal:
        """
        Configure the appearance of tick labels for the scale's axis or legend.

        .. note::
            This API is under construction and will be enhanced over time.

        Parameters
        ----------
        formatter : :class:`matplotlib.ticker.Formatter` subclass
            Pre-configured formatter to use; other parameters will be ignored.
        concise : bool
            If True, use :class:`matplotlib.dates.ConciseDateFormatter` to make
            the tick labels as compact as possible.

        Returns
        -------
        scale
            Copy of self with new label configuration.

        """
        configured = copy(self)
        configured._label_params = dict(formatter=formatter, concise=concise)
        return configured

    def _get_locators(self, locator, upto):
        """Return the (major, minor) locators; there is never a minor one."""
        if locator is not None:
            return locator, None
        # Without an explicit cap, allow at most six date ticks.
        max_ticks = 6 if upto is None else upto
        return AutoDateLocator(minticks=2, maxticks=max_ticks), None

    def _get_formatter(self, locator, formatter, concise):
        """Return the user formatter, or build a date formatter on `locator`."""
        if formatter is not None:
            return formatter

        # TODO ideally we would have concise coordinate ticks,
        # but full semantic ticks. Is that possible?
        formatter_cls = ConciseDateFormatter if concise else AutoDateFormatter
        return formatter_cls(locator)
def __init__(self, scale):
    """
    Set up the axis stand-in and let `scale` install its defaults on it.

    Parameters
    ----------
    scale
        Object stored as ``self.scale``; it must provide
        ``set_default_locators_and_formatters``, which is called with this
        instance as its argument.

    """
    # No unit converter or units are known until `update_units` /
    # `set_units` are called.
    self.converter = None
    self.units = None
    self.scale = scale
    # Containers for the major / minor locator and formatter; they must exist
    # before the scale is asked to install its defaults below.
    self.major = mpl.axis.Ticker()
    self.minor = mpl.axis.Ticker()

    # It appears that this needs to be initialized this way on matplotlib 3.1,
    # but not later versions. It is unclear whether there are any issues with it.
    self._data_interval = None, None

    # NOTE: `_view_interval` is not initialized here; it only exists after
    # `set_view_interval` has been called.
    scale.set_default_locators_and_formatters(self)
    # self.set_default_intervals()  Is this ever needed?
def convert_units(self, x):
    """
    Return a numeric representation of the input data.

    Data that is already numeric, or any data when no converter has been
    registered, is handed back unchanged (the very same object). Otherwise
    the registered converter translates it using the current units.
    """
    already_numeric = np.issubdtype(np.asarray(x).dtype, np.number)
    if already_numeric or self.converter is None:
        return x
    return self.converter.convert(x, self.units, self)
def _make_log_transforms(base: float | None = None) -> TransFuncs:
    """
    Build a (forward, inverse) pair of logarithm / exponential functions.

    Parameters
    ----------
    base : float or None
        Base of the logarithm; ``None`` selects the natural logarithm.

    Returns
    -------
    log, exp : callables
        Both evaluate with numpy "invalid" and "divide" floating point
        warnings suppressed.

    """
    def log_any_base(x):
        # Change-of-base formula for bases numpy has no dedicated ufunc for.
        return np.log(x) / np.log(base)

    forward_fn: Callable
    inverse_fn: Callable
    if base is None:
        forward_fn, inverse_fn = np.log, np.exp
    elif base == 2:
        forward_fn, inverse_fn = np.log2, partial(np.power, 2)
    elif base == 10:
        forward_fn, inverse_fn = np.log10, partial(np.power, 10)
    else:
        forward_fn, inverse_fn = log_any_base, partial(np.power, base)

    def log(x: ArrayLike) -> ArrayLike:
        with np.errstate(invalid="ignore", divide="ignore"):
            return forward_fn(x)

    def exp(x: ArrayLike) -> ArrayLike:
        with np.errstate(invalid="ignore", divide="ignore"):
            return inverse_fn(x)

    return log, exp
def _check_dimension_uniqueness(
    self, facet_spec: FacetSpec, pair_spec: PairSpec
) -> None:
    """
    Reject specs that pair and facet on (or wrap to) same figure dimension.

    Parameters
    ----------
    facet_spec : dict
        Faceting parameters; the "variables" and "wrap" entries are read.
    pair_spec : dict
        Pairing parameters; the "structure", "wrap" and "cross" entries are read.

    Raises
    ------
    RuntimeError
        If the two specs compete for the same figure dimension.

    """
    err = None

    facet_vars = facet_spec.get("variables", {})
    pair_structure = pair_spec.get("structure", {})

    if facet_spec.get("wrap") and {"col", "row"} <= set(facet_vars):
        err = "Cannot wrap facets when specifying both `col` and `row`."
    elif (
        pair_spec.get("wrap")
        and pair_spec.get("cross", True)
        and len(pair_structure.get("x", [])) > 1
        and len(pair_structure.get("y", [])) > 1
    ):
        err = "Cannot wrap subplots when pairing on both `x` and `y`."

    # Pairing on x lays subplots out across columns (and wraps onto rows);
    # pairing on y does the opposite. The first three letters of each
    # dimension name give the matching facet variable ("col" / "row").
    collisions = {"x": ["columns", "rows"], "y": ["rows", "columns"]}
    for pair_axis, (multi_dim, wrap_dim) in collisions.items():
        if pair_axis not in pair_structure:
            continue
        elif multi_dim[:3] in facet_vars:
            err = f"Cannot facet the {multi_dim} while pairing on `{pair_axis}`."
        elif wrap_dim[:3] in facet_vars and facet_spec.get("wrap"):
            err = f"Cannot wrap the {wrap_dim} while pairing on `{pair_axis}`."
        elif wrap_dim[:3] in facet_vars and pair_spec.get("wrap"):
            err = f"Cannot wrap the {multi_dim} while faceting the {wrap_dim}."

    if err is not None:
        raise RuntimeError(err)  # TODO what err class? Define PlotSpecError?
def init_figure(
    self,
    pair_spec: PairSpec,
    pyplot: bool = False,
    figure_kws: dict | None = None,
    target: Axes | Figure | SubFigure | None = None,
) -> Figure:
    """
    Initialize matplotlib objects and add seaborn-relevant metadata.

    Parameters
    ----------
    pair_spec : dict
        Pairing parameters; the "cross" and "structure" entries are read.
    pyplot : bool
        When no `target` is given, create the figure with `plt.figure`
        if True, otherwise instantiate `mpl.figure.Figure` directly.
    figure_kws : dict
        Keyword arguments for the figure constructor; only used when no
        `target` is given.
    target : Axes, Figure, or SubFigure
        Existing object to draw on. An Axes is used as the single subplot;
        for a Figure or SubFigure, the subplots are created on it.

    Returns
    -------
    figure
        The figure that holds the subplots. Also stored as ``self._figure``;
        per-subplot metadata is stored in ``self._subplot_list``.

    Raises
    ------
    RuntimeError
        If `target` is an Axes but the spec calls for more than one subplot.

    """
    # TODO reduce need to pass pair_spec here?

    if figure_kws is None:
        figure_kws = {}

    if isinstance(target, mpl.axes.Axes):

        # A single pre-existing Axes cannot host a grid of subplots.
        if max(self.subplot_spec["nrows"], self.subplot_spec["ncols"]) > 1:
            err = " ".join([
                "Cannot create multiple subplots after calling `Plot.on` with",
                f"a {mpl.axes.Axes} object.",
                f" You may want to use a {mpl.figure.SubFigure} instead.",
            ])
            raise RuntimeError(err)

        # The lone Axes sits on every edge and uses the plain x / y keys.
        self._subplot_list = [{
            "ax": target,
            "left": True,
            "right": True,
            "top": True,
            "bottom": True,
            "col": None,
            "row": None,
            "x": "x",
            "y": "y",
        }]
        self._figure = target.figure
        return self._figure

    elif isinstance(target, mpl.figure.SubFigure):
        # Subplots are created on the SubFigure, but its `.figure` attribute
        # is what gets stored and returned.
        figure = target.figure
    elif isinstance(target, mpl.figure.Figure):
        figure = target
    else:
        if pyplot:
            figure = plt.figure(**figure_kws)
        else:
            figure = mpl.figure.Figure(**figure_kws)
        target = figure
    self._figure = figure

    # squeeze=False keeps `axs` two-dimensional regardless of grid shape.
    axs = target.subplots(**self.subplot_spec, squeeze=False)

    if self.wrap:
        # Remove unused Axes and flatten the rest into a (2D) vector
        # (row-major order when wrapping columns, column-major for rows).
        axs_flat = axs.ravel({"col": "C", "row": "F"}[self.wrap_dim])
        axs, extra = np.split(axs_flat, [self.n_subplots])
        for ax in extra:
            ax.remove()
        if self.wrap_dim == "col":
            axs = axs[np.newaxis, :]
        else:
            axs = axs[:, np.newaxis]

    # Get i, j coordinates for each Axes object
    # Note that i, j are with respect to faceting/pairing,
    # not the subplot grid itself, (which only matters in the case of wrapping).
    iter_axs: np.ndenumerate | zip
    if not pair_spec.get("cross", True):
        # Without crossing, the k-th subplot gets coordinates (k, k).
        indices = np.arange(self.n_subplots)
        iter_axs = zip(zip(indices, indices), axs.flat)
    else:
        iter_axs = np.ndenumerate(axs)

    self._subplot_list = []
    for (i, j), ax in iter_axs:

        info = {"ax": ax}

        # Flag which outer edges of the grid this subplot touches.
        nrows, ncols = self.subplot_spec["nrows"], self.subplot_spec["ncols"]
        if not self.wrap:
            info["left"] = j % ncols == 0
            info["right"] = (j + 1) % ncols == 0
            info["top"] = i == 0
            info["bottom"] = i == nrows - 1
        elif self.wrap_dim == "col":
            # Here j is the flat position; the final subplot counts as a
            # right edge even when it does not complete its row.
            info["left"] = j % ncols == 0
            info["right"] = ((j + 1) % ncols == 0) or ((j + 1) == self.n_subplots)
            info["top"] = j < ncols
            info["bottom"] = j >= (self.n_subplots - ncols)
        elif self.wrap_dim == "row":
            # Mirror image of the "col" case, with i as the flat position.
            info["left"] = i < nrows
            info["right"] = i >= self.n_subplots - nrows
            info["top"] = i % nrows == 0
            info["bottom"] = ((i + 1) % nrows == 0) or ((i + 1) == self.n_subplots)

        # Non-crossed pairing overrides the top / bottom flags set above.
        if not pair_spec.get("cross", True):
            info["top"] = j < ncols
            info["bottom"] = j >= self.n_subplots - ncols

        # Record the entry of each grid dimension at this position.
        for dim in ["row", "col"]:
            idx = {"row": i, "col": j}[dim]
            info[dim] = self.grid_dimensions[dim][idx]

        # Paired axes get an indexed key (e.g. "x0", "x1"); others the bare name.
        for axis in "xy":

            idx = {"x": j, "y": i}[axis]
            if axis in pair_spec.get("structure", {}):
                key = f"{axis}{idx}"
            else:
                key = axis
            info[axis] = key

        self._subplot_list.append(info)

    return figure
def __len__(self) -> int:
    """Report how many subplot dictionaries this object currently holds."""
    subplots = self._subplot_list
    return len(subplots)
class DocstringComponents:
    """
    Container of reusable docstring snippets with attribute-style access.

    Snippets are stored in the ``entries`` dict and can be read as
    ``components.name``.
    """

    # Captures the text between the first newline and the trailing
    # newline-plus-whitespace of a triple-quoted snippet.
    regexp = re.compile(r"\n((\n|.)+)\n\s*", re.MULTILINE)

    def __init__(self, comp_dict, strip_whitespace=True):
        """Read entries from a dict, optionally stripping outer whitespace."""
        if strip_whitespace:
            entries = {}
            for key, val in comp_dict.items():
                m = re.match(self.regexp, val)
                if m is None:
                    entries[key] = val
                else:
                    entries[key] = m.group(1)
        else:
            entries = comp_dict.copy()

        self.entries = entries

    def __getattr__(self, attr):
        """Provide dot access to entries for clean raw docstrings."""
        # Read ``entries`` from the instance dict rather than through normal
        # attribute access: on an instance whose ``entries`` is not set yet
        # (as created by ``copy.copy`` or unpickling), ``self.entries`` would
        # re-enter this method without bound and raise RecursionError.
        entries = self.__dict__.get("entries", {})
        if attr in entries:
            return entries[attr]
        else:
            try:
                return self.__getattribute__(attr)
            except AttributeError as err:
                # If Python is run with -OO, it will strip docstrings and our lookup
                # from self.entries will fail. We check for __debug__, which is actually
                # set to False by -O (it is True for normal execution).
                # But we only want to see an error when building the docs;
                # not something users should see, so this slight inconsistency is fine.
                if __debug__:
                    raise err
                else:
                    pass

    @classmethod
    def from_nested_components(cls, **kwargs):
        """Add multiple sub-sets of components."""
        return cls(kwargs, strip_whitespace=False)

    @classmethod
    def from_function_params(cls, func):
        """Use the numpydoc parser to extract components from existing func."""
        params = NumpyDocString(pydoc.getdoc(func))["Parameters"]
        comp_dict = {}
        for p in params:
            name = p.name
            type = p.type
            desc = "\n ".join(p.desc)
            comp_dict[name] = f"{name} : {type}\n {desc}"

        return cls(comp_dict)
# "See Also" one-liners, keyed by the lower-cased name of the object they
# describe. Each blurb starts with the name of its target, so the ``pairplot``
# entry is labelled ``pairplot`` (it previously repeated ``jointplot``).
_seealso_blurbs = dict(

    # Relational plots
    scatterplot="""
scatterplot : Plot data using points.
    """,
    lineplot="""
lineplot : Plot data using lines.
    """,

    # Distribution plots
    displot="""
displot : Figure-level interface to distribution plot functions.
    """,
    histplot="""
histplot : Plot a histogram of binned counts with optional normalization or smoothing.
    """,
    kdeplot="""
kdeplot : Plot univariate or bivariate distributions using kernel density estimation.
    """,
    ecdfplot="""
ecdfplot : Plot empirical cumulative distribution functions.
    """,
    rugplot="""
rugplot : Plot a tick at each observation value along the x and/or y axes.
    """,

    # Categorical plots
    stripplot="""
stripplot : Plot a categorical scatter with jitter.
    """,
    swarmplot="""
swarmplot : Plot a categorical scatter with non-overlapping points.
    """,
    violinplot="""
violinplot : Draw an enhanced boxplot using kernel density estimation.
    """,
    pointplot="""
pointplot : Plot point estimates and CIs using markers and lines.
    """,

    # Multiples
    jointplot="""
jointplot : Draw a bivariate plot with univariate marginal distributions.
    """,
    pairplot="""
pairplot : Draw multiple bivariate plots with univariate marginal distributions.
    """,
    jointgrid="""
JointGrid : Set up a figure with joint and marginal views on bivariate data.
    """,
    pairgrid="""
PairGrid : Set up a figure with joint and marginal views on multiple variables.
    """,
)
orient): return data def _postprocess_artist(self, artist, ax, orient): pass def _get_verts(self, data, orient): dv = {"x": "y", "y": "x"}[orient] data = data.sort_values(orient, kind="mergesort") verts = np.concatenate([ data[[orient, f"{dv}min"]].to_numpy(), data[[orient, f"{dv}max"]].to_numpy()[::-1], ]) if orient == "y": verts = verts[:, ::-1] return verts def _legend_artist(self, variables, value, scales): keys = {v: value for v in variables} resolved = resolve_properties(self, keys, scales) fc = resolve_color(self, keys, "", scales) if not resolved["fill"]: fc = mpl.colors.to_rgba(fc, 0) return mpl.patches.Patch( facecolor=fc, edgecolor=resolve_color(self, keys, "edge", scales), linewidth=resolved["edgewidth"], linestyle=resolved["edgestyle"], **self.artist_kws, ) @document_properties @dataclass class Area(AreaBase, Mark): """ A fill mark drawn from a baseline to data values. See also -------- Band : A fill mark representing an interval between values. Examples -------- .. include:: ../docstrings/objects.Area.rst """ color: MappableColor = Mappable("C0", ) alpha: MappableFloat = Mappable(.2, ) fill: MappableBool = Mappable(True, ) edgecolor: MappableColor = Mappable(depend="color") edgealpha: MappableFloat = Mappable(1, ) edgewidth: MappableFloat = Mappable(rc="patch.linewidth", ) edgestyle: MappableStyle = Mappable("-", ) # TODO should this be settable / mappable? 
baseline: MappableFloat = Mappable(0, grouping=False) def _standardize_coordinate_parameters(self, data, orient): dv = {"x": "y", "y": "x"}[orient] return data.rename(columns={"baseline": f"{dv}min", dv: f"{dv}max"}) def _postprocess_artist(self, artist, ax, orient): # TODO copying a lot of code from Bar, let's abstract this # See comments there, I am not going to repeat them too artist.set_linewidth(artist.get_linewidth() * 2) linestyle = artist.get_linestyle() if linestyle[1]: linestyle = (linestyle[0], tuple(x / 2 for x in linestyle[1])) artist.set_linestyle(linestyle) artist.set_clip_path(artist.get_path(), artist.get_transform() + ax.transData) if self.artist_kws.get("clip_on", True): artist.set_clip_box(ax.bbox) val_idx = ["y", "x"].index(orient) artist.sticky_edges[val_idx][:] = (0, np.inf) @document_properties @dataclass class Band(AreaBase, Mark): """ A fill mark representing an interval between values. See also -------- Area : A fill mark drawn from a baseline to data values. Examples -------- .. include:: ../docstrings/objects.Band.rst """ color: MappableColor = Mappable("C0", ) alpha: MappableFloat = Mappable(.2, ) fill: MappableBool = Mappable(True, ) edgecolor: MappableColor = Mappable(depend="color", ) edgealpha: MappableFloat = Mappable(1, ) edgewidth: MappableFloat = Mappable(0, ) edgestyle: MappableFloat = Mappable("-", ) def _standardize_coordinate_parameters(self, data, orient): # dv = {"x": "y", "y": "x"}[orient] # TODO assert that all(ymax >= ymin)? # TODO what if only one exist? 
other = {"x": "y", "y": "x"}[orient] if not set(data.columns) & {f"{other}min", f"{other}max"}: agg = {f"{other}min": (other, "min"), f"{other}max": (other, "max")} data = data.groupby(orient).agg(**agg).reset_index() return data ================================================ FILE: seaborn/_marks/bar.py ================================================ from __future__ import annotations from collections import defaultdict from dataclasses import dataclass import numpy as np import matplotlib as mpl from seaborn._marks.base import ( Mark, Mappable, MappableBool, MappableColor, MappableFloat, MappableStyle, resolve_properties, resolve_color, document_properties ) from typing import TYPE_CHECKING if TYPE_CHECKING: from typing import Any from matplotlib.artist import Artist from seaborn._core.scales import Scale class BarBase(Mark): def _make_patches(self, data, scales, orient): transform = scales[orient]._matplotlib_scale.get_transform() forward = transform.transform reverse = transform.inverted().transform other = {"x": "y", "y": "x"}[orient] pos = reverse(forward(data[orient]) - data["width"] / 2) width = reverse(forward(data[orient]) + data["width"] / 2) - pos val = (data[other] - data["baseline"]).to_numpy() base = data["baseline"].to_numpy() kws = self._resolve_properties(data, scales) if orient == "x": kws.update(x=pos, y=base, w=width, h=val) else: kws.update(x=base, y=pos, w=val, h=width) kws.pop("width", None) kws.pop("baseline", None) val_dim = {"x": "h", "y": "w"}[orient] bars, vals = [], [] for i in range(len(data)): row = {k: v[i] for k, v in kws.items()} # Skip bars with no value. It's possible we'll want to make this # an option (i.e so you have an artist for animating or annotating), # but let's keep things simple for now. 
if not np.nan_to_num(row[val_dim]): continue bar = mpl.patches.Rectangle( xy=(row["x"], row["y"]), width=row["w"], height=row["h"], facecolor=row["facecolor"], edgecolor=row["edgecolor"], linestyle=row["edgestyle"], linewidth=row["edgewidth"], **self.artist_kws, ) bars.append(bar) vals.append(row[val_dim]) return bars, vals def _resolve_properties(self, data, scales): resolved = resolve_properties(self, data, scales) resolved["facecolor"] = resolve_color(self, data, "", scales) resolved["edgecolor"] = resolve_color(self, data, "edge", scales) fc = resolved["facecolor"] if isinstance(fc, tuple): resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"] else: fc[:, 3] = fc[:, 3] * resolved["fill"] # TODO Is inplace mod a problem? resolved["facecolor"] = fc return resolved def _legend_artist( self, variables: list[str], value: Any, scales: dict[str, Scale], ) -> Artist: # TODO return some sensible default? key = {v: value for v in variables} key = self._resolve_properties(key, scales) artist = mpl.patches.Patch( facecolor=key["facecolor"], edgecolor=key["edgecolor"], linewidth=key["edgewidth"], linestyle=key["edgestyle"], ) return artist @document_properties @dataclass class Bar(BarBase): """ A bar mark drawn between baseline and data values. See also -------- Bars : A faster bar mark with defaults more suitable for histograms. Examples -------- .. 
include:: ../docstrings/objects.Bar.rst """ color: MappableColor = Mappable("C0", grouping=False) alpha: MappableFloat = Mappable(.7, grouping=False) fill: MappableBool = Mappable(True, grouping=False) edgecolor: MappableColor = Mappable(depend="color", grouping=False) edgealpha: MappableFloat = Mappable(1, grouping=False) edgewidth: MappableFloat = Mappable(rc="patch.linewidth", grouping=False) edgestyle: MappableStyle = Mappable("-", grouping=False) # pattern: MappableString = Mappable(None) # TODO no Property yet width: MappableFloat = Mappable(.8, grouping=False) baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable? def _plot(self, split_gen, scales, orient): val_idx = ["y", "x"].index(orient) for _, data, ax in split_gen(): bars, vals = self._make_patches(data, scales, orient) for bar in bars: # Because we are clipping the artist (see below), the edges end up # looking half as wide as they actually are. I don't love this clumsy # workaround, which is going to cause surprises if you work with the # artists directly. We may need to revisit after feedback. bar.set_linewidth(bar.get_linewidth() * 2) linestyle = bar.get_linestyle() if linestyle[1]: linestyle = (linestyle[0], tuple(x / 2 for x in linestyle[1])) bar.set_linestyle(linestyle) # This is a bit of a hack to handle the fact that the edge lines are # centered on the actual extents of the bar, and overlap when bars are # stacked or dodged. We may discover that this causes problems and needs # to be revisited at some point. Also it should be faster to clip with # a bbox than a path, but I cant't work out how to get the intersection # with the axes bbox. bar.set_clip_path(bar.get_path(), bar.get_transform() + ax.transData) if self.artist_kws.get("clip_on", True): # It seems the above hack undoes the default axes clipping bar.set_clip_box(ax.bbox) bar.sticky_edges[val_idx][:] = (0, np.inf) ax.add_patch(bar) # Add a container which is useful for, e.g. 
Axes.bar_label orientation = {"x": "vertical", "y": "horizontal"}[orient] container_kws = dict(datavalues=vals, orientation=orientation) container = mpl.container.BarContainer(bars, **container_kws) ax.add_container(container) @document_properties @dataclass class Bars(BarBase): """ A faster bar mark with defaults more suitable for histograms. See also -------- Bar : A bar mark drawn between baseline and data values. Examples -------- .. include:: ../docstrings/objects.Bars.rst """ color: MappableColor = Mappable("C0", grouping=False) alpha: MappableFloat = Mappable(.7, grouping=False) fill: MappableBool = Mappable(True, grouping=False) edgecolor: MappableColor = Mappable(rc="patch.edgecolor", grouping=False) edgealpha: MappableFloat = Mappable(1, grouping=False) edgewidth: MappableFloat = Mappable(auto=True, grouping=False) edgestyle: MappableStyle = Mappable("-", grouping=False) # pattern: MappableString = Mappable(None) # TODO no Property yet width: MappableFloat = Mappable(1, grouping=False) baseline: MappableFloat = Mappable(0, grouping=False) # TODO *is* this mappable? 
def _plot(self, split_gen, scales, orient): ori_idx = ["x", "y"].index(orient) val_idx = ["y", "x"].index(orient) patches = defaultdict(list) for _, data, ax in split_gen(): bars, _ = self._make_patches(data, scales, orient) patches[ax].extend(bars) collections = {} for ax, ax_patches in patches.items(): col = mpl.collections.PatchCollection(ax_patches, match_original=True) col.sticky_edges[val_idx][:] = (0, np.inf) ax.add_collection(col, autolim=False) collections[ax] = col # Workaround for matplotlib autoscaling bug # https://github.com/matplotlib/matplotlib/issues/11898 # https://github.com/matplotlib/matplotlib/issues/23129 xys = np.vstack([path.vertices for path in col.get_paths()]) ax.update_datalim(xys) if "edgewidth" not in scales and isinstance(self.edgewidth, Mappable): for ax in collections: ax.autoscale_view() def get_dimensions(collection): edges, widths = [], [] for verts in (path.vertices for path in collection.get_paths()): edges.append(min(verts[:, ori_idx])) widths.append(np.ptp(verts[:, ori_idx])) return np.array(edges), np.array(widths) min_width = np.inf for ax, col in collections.items(): edges, widths = get_dimensions(col) points = 72 / ax.figure.dpi * abs( ax.transData.transform([edges + widths] * 2) - ax.transData.transform([edges] * 2) ) min_width = min(min_width, min(points[:, ori_idx])) linewidth = min(.1 * min_width, mpl.rcParams["patch.linewidth"]) for _, col in collections.items(): col.set_linewidth(linewidth) ================================================ FILE: seaborn/_marks/base.py ================================================ from __future__ import annotations from dataclasses import dataclass, fields, field import textwrap from typing import Any, Callable, Union from collections.abc import Generator import numpy as np import pandas as pd import matplotlib as mpl from numpy import ndarray from pandas import DataFrame from matplotlib.artist import Artist from seaborn._core.scales import Scale from seaborn._core.properties 
class Mappable:
    """Descriptor-like default for a mark property that may be mapped or set."""

    def __init__(
        self,
        val: Any = None,
        depend: str | None = None,
        rc: str | None = None,
        auto: bool = False,
        grouping: bool = True,
    ):
        """
        Property that can be mapped from data or set directly, with flexible defaults.

        Parameters
        ----------
        val : Any
            Use this value as the default.
        depend : str
            Use the value of this feature as the default.
        rc : str
            Use the value of this rcParam as the default.
        auto : bool
            The default value will depend on other parameters at compile time.
        grouping : bool
            If True, use the mapped variable to define groups.

        """
        if depend is not None:
            assert depend in PROPERTIES
        if rc is not None:
            assert rc in mpl.rcParams

        self._val = val
        self._rc = rc
        self._depend = depend
        self._auto = auto
        self._grouping = grouping

    def __repr__(self):
        """Nice formatting for when object appears in Mark init signature."""
        # Each branch names the source of the default so that a Mark signature
        # such as `Dot(color=<'C0'>, edgecolor=<depend:color>)` is informative;
        # the precedence mirrors the order in which defaults are declared.
        if self._val is not None:
            s = f"<{repr(self._val)}>"
        elif self._depend is not None:
            s = f"<depend:{self._depend}>"
        elif self._rc is not None:
            s = f"<rc:{self._rc}>"
        elif self._auto:
            s = "<auto>"
        else:
            s = "<undefined>"
        return s

    @property
    def depend(self) -> Any:
        """Return the name of the feature to source a default value from."""
        return self._depend

    @property
    def grouping(self) -> bool:
        """Return whether a variable mapped to this property defines groups."""
        return self._grouping

    @property
    def default(self) -> Any:
        """Get the default value for this feature, or access the relevant rcParam.

        Returns None when neither a literal value nor an rcParam was given
        (i.e. for `depend`-only, `auto`, or unset properties).
        """
        if self._val is not None:
            return self._val
        elif self._rc is not None:
            return mpl.rcParams.get(self._rc)
        return None
# Accepted types for mark properties: either a directly-specified value of the
# named kind or a `Mappable` placeholder carrying the default.
MappableBool = Union[bool, Mappable]
MappableString = Union[str, Mappable]
MappableFloat = Union[float, Mappable]
MappableColor = Union[str, tuple, Mappable]
MappableStyle = Union[str, DashPattern, DashPatternWithOffset, Mappable]


@dataclass
class Mark:
    """Base class for objects that visually represent data."""

    artist_kws: dict = field(default_factory=dict)

    @property
    def _mappable_props(self):
        """Map each field whose declared default is a Mappable to its current value."""
        return {
            f.name: getattr(self, f.name) for f in fields(self)
            if isinstance(f.default, Mappable)
        }

    @property
    def _grouping_props(self):
        """List the names of Mappable fields whose default has `grouping` set."""
        # TODO does it make sense to have variation within a Mark's
        # properties about whether they are grouping?
        return [
            f.name for f in fields(self)
            if isinstance(f.default, Mappable) and f.default.grouping
        ]

    # TODO make this method private? Would extender every need to call directly?
    def _resolve(
        self,
        data: DataFrame | dict[str, Any],
        name: str,
        scales: dict[str, Scale] | None = None,
    ) -> Any:
        """Obtain default, specified, or mapped value for a named feature.

        Resolution order: a value set directly on the mark, then a value found
        in `data` (passed through the matching scale when one exists), then the
        feature this one depends on, then the declared default.

        Parameters
        ----------
        data : DataFrame or dict with scalar values
            Container with data values for features that will be semantically mapped.
        name : string
            Identity of the feature / semantic.
        scales: dict
            Mapping from variable to corresponding scale object.

        Returns
        -------
        value or array of values
            Outer return type depends on whether `data` is a dict (implying that
            we want a single value) or DataFrame (implying that we want an array
            of values with matching length).

        Raises
        ------
        PlotSpecError
            If applying the scale to the data raises any exception.

        """
        feature = self._mappable_props[name]
        prop = PROPERTIES.get(name, Property(name))
        directly_specified = not isinstance(feature, Mappable)
        return_multiple = isinstance(data, pd.DataFrame)
        # Properties named "*style" are kept as lists rather than arrays
        return_array = return_multiple and not name.endswith("style")

        # Special case width because it needs to be resolved and added to the dataframe
        # during layer prep (so the Move operations use it properly).
        # TODO how does width *scaling* work, e.g. for violin width by count?
        if name == "width":
            directly_specified = directly_specified and name not in data

        if directly_specified:
            feature = prop.standardize(feature)
            if return_multiple:
                feature = [feature] * len(data)
            if return_array:
                feature = np.array(feature)
            return feature

        if name in data:
            if scales is None or name not in scales:
                # TODO Might this obviate the identity scale? Just don't add a scale?
                feature = data[name]
            else:
                scale = scales[name]
                value = data[name]
                try:
                    feature = scale(value)
                except Exception as err:
                    raise PlotSpecError._during("Scaling operation", name) from err

            if return_array:
                feature = np.asarray(feature)
            return feature

        if feature.depend is not None:
            # TODO add source_func or similar to transform the source value?
            # e.g. set linewidth as a proportion of pointsize?
            return self._resolve(data, feature.depend, scales)

        default = prop.standardize(feature.default)
        if return_multiple:
            default = [default] * len(data)
        if return_array:
            default = np.array(default)
        return default

    def _infer_orient(self, scales: dict) -> str:  # TODO type scales
        """Return "y" if the y scale has strictly higher `_priority`, else "x".

        A coordinate without a scale counts as priority 0, so ties and the
        no-scale case both resolve to "x".
        """
        # TODO The original version of this (in seaborn._base) did more checking.
        # Paring that down here for the prototype to see what restrictions make sense.

        # TODO rethink this to map from scale type to "DV priority" and use that?
        # e.g. Nominal > Discrete > Continuous

        x = 0 if "x" not in scales else scales["x"]._priority
        y = 0 if "y" not in scales else scales["y"]._priority

        if y > x:
            return "y"
        else:
            return "x"

    def _plot(
        self,
        split_generator: Callable[[], Generator],
        scales: dict[str, Scale],
        orient: str,
    ) -> None:
        """Main interface for creating a plot."""
        raise NotImplementedError()

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist | None:
        """Return an artist representing `value` in a legend; None in the base class."""
        return None


def resolve_properties(
    mark: Mark, data: DataFrame, scales: dict[str, Scale]
) -> dict[str, Any]:
    """Resolve every mappable property of `mark` against `data` and `scales`.

    Returns a dict keyed by property name whose values come from
    `Mark._resolve`.
    """
    props = {
        name: mark._resolve(data, name, scales) for name in mark._mappable_props
    }
    return props


def resolve_color(
    mark: Mark,
    data: DataFrame | dict,
    prefix: str = "",
    scales: dict[str, Scale] | None = None,
) -> RGBATuple | ndarray:
    """
    Obtain a default, specified, or mapped value for a color feature.

    This method exists separately to support the relationship between a
    color and its corresponding alpha. We want to respect alpha values that
    are passed in specified (or mapped) color values but also make use of a
    separate `alpha` variable, which can be mapped. This approach may also
    be extended to support mapping of specific color channels (i.e.
    luminance, chroma) in the future.

    Parameters
    ----------
    mark :
        Mark with the color property.
    data :
        Container with data values for features that will be semantically mapped.
    prefix :
        Support "color", "fillcolor", etc.
    scales :
        Mapping from variable to corresponding scale object.

    Returns
    -------
    A single RGBA tuple when the resolved color is one color, otherwise an
    array of RGBA values. Colors with non-finite channels get an alpha of nan.

    """
    color = mark._resolve(data, f"{prefix}color", scales)

    # Prefer the prefixed alpha ("fillalpha", "edgealpha") when the mark has one
    if f"{prefix}alpha" in mark._mappable_props:
        alpha = mark._resolve(data, f"{prefix}alpha", scales)
    else:
        alpha = mark._resolve(data, "alpha", scales)

    def visible(x, axis=None):
        """Detect "invisible" colors to set alpha appropriately."""
        # TODO First clause only needed to handle non-rgba arrays,
        # which we are trying to handle upstream
        return np.array(x).dtype.kind != "f" or np.isfinite(x).all(axis)

    # Second check here catches vectors of strings with identity scale
    # It could probably be handled better upstream. This is a tricky problem
    if np.ndim(color) < 2 and all(isinstance(x, float) for x in color):
        # A single color given as a sequence of floats
        if len(color) == 4:
            # Already carries its own alpha channel; leave it untouched
            return mpl.colors.to_rgba(color)
        alpha = alpha if visible(color) else np.nan
        return mpl.colors.to_rgba(color, alpha)
    else:
        # A collection of colors (or a single non-float spec such as a name)
        if np.ndim(color) == 2 and color.shape[1] == 4:
            return mpl.colors.to_rgba_array(color)
        alpha = np.where(visible(color, axis=1), alpha, np.nan)
        return mpl.colors.to_rgba_array(color, alpha)

    # TODO should we be implementing fill here too?
    # (i.e. set fillalpha to 0 when fill=False)
def document_properties(mark):
    """Class decorator that lists the mark's mappable properties in its docstring.

    The property names (fields whose default is a Mappable) are wrapped as
    ``|name|`` tokens and inserted after the first two lines of the existing
    docstring. The class is modified in place and returned.
    """
    properties = [f.name for f in fields(mark) if isinstance(f.default, Mappable)]
    text = [
        "",
        " This mark defines the following properties:",
        textwrap.fill(
            ", ".join([f"|{p}|" for p in properties]),
            width=78, initial_indent=" " * 8, subsequent_indent=" " * 8,
        ),
    ]

    # The first two docstring lines (the blank opener and the summary) stay
    # on top; the generated text goes right after them.
    docstring_lines = mark.__doc__.split("\n")
    new_docstring = "\n".join([
        *docstring_lines[:2],
        *text,
        *docstring_lines[2:],
    ])
    mark.__doc__ = new_docstring
    return mark


class DotBase(Mark):
    # Shared drawing logic for the dot marks; subclasses are expected to
    # provide "facecolor", "edgecolor", "linewidth" and "edgestyle" through
    # their own `_resolve_properties`.

    def _resolve_paths(self, data):
        """Return the transformed path for each marker in ``data["marker"]``.

        A single MarkerStyle yields a single path; otherwise a list with one
        path per marker is returned.
        """
        paths = []
        # Markers repeat heavily, so transform each distinct one only once
        path_cache = {}
        marker = data["marker"]

        def get_transformed_path(m):
            return m.get_path().transformed(m.get_transform())

        if isinstance(marker, mpl.markers.MarkerStyle):
            return get_transformed_path(marker)

        for m in marker:
            if m not in path_cache:
                path_cache[m] = get_transformed_path(m)
            paths.append(path_cache[m])
        return paths

    def _resolve_properties(self, data, scales):
        """Resolve mark properties and derive "path", "size" and "fill"."""
        resolved = resolve_properties(self, data, scales)
        resolved["path"] = self._resolve_paths(resolved)
        # Collection sizes are areas, so square the point size
        resolved["size"] = resolved["pointsize"] ** 2

        if isinstance(data, dict):  # Properties for single dot
            filled_marker = resolved["marker"].is_filled()
        else:
            filled_marker = [m.is_filled() for m in resolved["marker"]]

        # Unfilled marker shapes override a requested fill
        resolved["fill"] = resolved["fill"] * filled_marker

        return resolved

    def _plot(self, split_gen, scales, orient):
        """Add one PathCollection of dots to the axes of each data subset."""
        # TODO Not backcompat with allowed (but nonfunctional) univariate plots
        # (That should be solved upstream by defaulting to "" for unset x/y?)
        # (Be mindful of xmin/xmax, etc!)

        for _, data, ax in split_gen():

            # Grab the coordinates first: `data` is rebound to the resolved
            # property dict on the next line
            offsets = np.column_stack([data["x"], data["y"]])
            data = self._resolve_properties(data, scales)

            points = mpl.collections.PathCollection(
                offsets=offsets,
                paths=data["path"],
                sizes=data["size"],
                facecolors=data["facecolor"],
                edgecolors=data["edgecolor"],
                linewidths=data["linewidth"],
                linestyles=data["edgestyle"],
                transOffset=ax.transData,
                transform=mpl.transforms.IdentityTransform(),
                **self.artist_kws,
            )
            ax.add_collection(points)

    def _legend_artist(
        self, variables: list[str], value: Any, scales: dict[str, Scale],
    ) -> Artist:
        """Build a single-dot PathCollection showing `value` for `variables`."""
        key = {v: value for v in variables}
        res = self._resolve_properties(key, scales)

        return mpl.collections.PathCollection(
            paths=[res["path"]],
            sizes=[res["size"]],
            facecolors=[res["facecolor"]],
            edgecolors=[res["edgecolor"]],
            linewidths=[res["linewidth"]],
            linestyles=[res["edgestyle"]],
            transform=mpl.transforms.IdentityTransform(),
            **self.artist_kws,
        )


@document_properties
@dataclass
class Dot(DotBase):
    """
    A mark suitable for dot plots or less-dense scatterplots.

    See also
    --------
    Dots : A dot mark defined by strokes to better handle overplotting.

    Examples
    --------
    .. include:: ../docstrings/objects.Dot.rst

    """
    marker: MappableString = Mappable("o", grouping=False)
    pointsize: MappableFloat = Mappable(6, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)
    fill: MappableBool = Mappable(True, grouping=False)
    edgecolor: MappableColor = Mappable(depend="color", grouping=False)
    edgealpha: MappableFloat = Mappable(depend="alpha", grouping=False)
    edgewidth: MappableFloat = Mappable(.5, grouping=False)  # TODO rcParam?
    edgestyle: MappableStyle = Mappable("-", grouping=False)

    def _resolve_properties(self, data, scales):
        """Add "linewidth", "edgecolor" and "facecolor" to the resolved properties.

        Filled dots take their outline from the edge properties; unfilled dots
        are drawn with the main color and `stroke` width, and their face alpha
        is multiplied by the (falsy) fill value.
        """
        resolved = super()._resolve_properties(data, scales)
        filled = resolved["fill"]

        main_stroke = resolved["stroke"]
        edge_stroke = resolved["edgewidth"]
        resolved["linewidth"] = np.where(filled, edge_stroke, main_stroke)

        main_color = resolve_color(self, data, "", scales)
        edge_color = resolve_color(self, data, "edge", scales)

        if not np.isscalar(filled):
            # Expand dims to use in np.where with rgba arrays
            filled = filled[:, None]
        resolved["edgecolor"] = np.where(filled, edge_color, main_color)

        filled = np.squeeze(filled)
        if isinstance(main_color, tuple):
            # TODO handle this in resolve_color
            main_color = tuple([*main_color[:3], main_color[3] * filled])
        else:
            main_color = np.c_[main_color[:, :3], main_color[:, 3] * filled]
        resolved["facecolor"] = main_color

        return resolved


@document_properties
@dataclass
class Dots(DotBase):
    """
    A dot mark defined by strokes to better handle overplotting.

    See also
    --------
    Dot : A mark suitable for dot plots or less-dense scatterplots.

    Examples
    --------
    .. include:: ../docstrings/objects.Dots.rst

    """
    # TODO retype marker as MappableMarker
    marker: MappableString = Mappable(rc="scatter.marker", grouping=False)
    pointsize: MappableFloat = Mappable(4, grouping=False)  # TODO rcParam?
    stroke: MappableFloat = Mappable(.75, grouping=False)  # TODO rcParam?
    color: MappableColor = Mappable("C0", grouping=False)
    alpha: MappableFloat = Mappable(1, grouping=False)  # TODO auto alpha?
    fill: MappableBool = Mappable(True, grouping=False)
    fillcolor: MappableColor = Mappable(depend="color", grouping=False)
    fillalpha: MappableFloat = Mappable(.2, grouping=False)

    def _resolve_properties(self, data, scales):
        """Add "linewidth", "facecolor", "edgecolor" and "edgestyle".

        The outline uses the main color and `stroke` width; the face uses the
        fill color with its alpha multiplied by the fill value.
        """
        resolved = super()._resolve_properties(data, scales)
        resolved["linewidth"] = resolved.pop("stroke")
        resolved["facecolor"] = resolve_color(self, data, "fill", scales)
        resolved["edgecolor"] = resolve_color(self, data, "", scales)
        # This mark declares no edgestyle property, but the base class reads
        # one when building the collection
        resolved.setdefault("edgestyle", (0, None))

        fc = resolved["facecolor"]
        if isinstance(fc, tuple):
            resolved["facecolor"] = fc[0], fc[1], fc[2], fc[3] * resolved["fill"]
        else:
            fc[:, 3] = fc[:, 3] * resolved["fill"]  # TODO Is inplace mod a problem?
            resolved["facecolor"] = fc

        return resolved


@document_properties
@dataclass
class Path(Mark):
    """
    A mark connecting data points in the order they appear.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.
    Paths : A faster but less-flexible mark for drawing many paths.

    Examples
    --------
    .. include:: ../docstrings/objects.Path.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")
    marker: MappableString = Mappable(rc="lines.marker")
    pointsize: MappableFloat = Mappable(rc="lines.markersize")
    fillcolor: MappableColor = Mappable(depend="color")
    edgecolor: MappableColor = Mappable(depend="color")
    edgewidth: MappableFloat = Mappable(rc="lines.markeredgewidth")

    # Whether to sort each subset along the orient axis before drawing
    _sort: ClassVar[bool] = False

    def _plot(self, split_gen, scales, orient):
        """Add one Line2D per data subset, styled from the subset's keys."""
        # Missing values are kept only when the data are not sorted
        for keys, data, ax in split_gen(keep_na=not self._sort):

            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)
            vals["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
            vals["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

            if self._sort:
                data = data.sort_values(orient, kind="mergesort")

            # Copy so the capstyle workaround does not leak into the mark
            artist_kws = self.artist_kws.copy()
            self._handle_capstyle(artist_kws, vals)

            line = mpl.lines.Line2D(
                data["x"].to_numpy(),
                data["y"].to_numpy(),
                color=vals["color"],
                linewidth=vals["linewidth"],
                linestyle=vals["linestyle"],
                marker=vals["marker"],
                markersize=vals["pointsize"],
                markerfacecolor=vals["fillcolor"],
                markeredgecolor=vals["edgecolor"],
                markeredgewidth=vals["edgewidth"],
                **artist_kws,
            )
            ax.add_line(line)

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D styled for `value` of the given variables."""
        keys = {v: value for v in variables}
        vals = resolve_properties(self, keys, scales)
        vals["color"] = resolve_color(self, keys, scales=scales)
        vals["fillcolor"] = resolve_color(self, keys, prefix="fill", scales=scales)
        vals["edgecolor"] = resolve_color(self, keys, prefix="edge", scales=scales)

        artist_kws = self.artist_kws.copy()
        self._handle_capstyle(artist_kws, vals)

        return mpl.lines.Line2D(
            [], [],
            color=vals["color"],
            linewidth=vals["linewidth"],
            linestyle=vals["linestyle"],
            marker=vals["marker"],
            markersize=vals["pointsize"],
            markerfacecolor=vals["fillcolor"],
            markeredgecolor=vals["edgecolor"],
            markeredgewidth=vals["edgewidth"],
            **artist_kws,
        )

    def _handle_capstyle(self, kws, vals):
        """Set ``kws["dash_capstyle"]`` in place when the linestyle has no dashes."""
        # Work around for this matplotlib issue:
        # https://github.com/matplotlib/matplotlib/issues/23437
        if vals["linestyle"][1] is None:
            capstyle = kws.get("solid_capstyle", mpl.rcParams["lines.solid_capstyle"])
            kws["dash_capstyle"] = capstyle


@document_properties
@dataclass
class Line(Path):
    """
    A mark connecting data points with sorting along the orientation axis.

    See also
    --------
    Path : A mark connecting data points in the order they appear.
    Lines : A faster but less-flexible mark for drawing many lines.

    Examples
    --------
    .. include:: ../docstrings/objects.Line.rst

    """
    _sort: ClassVar[bool] = True


@document_properties
@dataclass
class Paths(Mark):
    """
    A faster but less-flexible mark for drawing many paths.

    See also
    --------
    Path : A mark connecting data points in the order they appear.

    Examples
    --------
    .. include:: ../docstrings/objects.Paths.rst

    """
    color: MappableColor = Mappable("C0")
    alpha: MappableFloat = Mappable(1)
    linewidth: MappableFloat = Mappable(rc="lines.linewidth")
    linestyle: MappableString = Mappable(rc="lines.linestyle")

    # Whether to sort each subset along the orient axis before drawing
    _sort: ClassVar[bool] = False

    def __post_init__(self):

        # LineCollection artists have a capstyle property but don't source its value
        # from the rc, so we do that manually here. Unfortunately, because we add
        # only one LineCollection, we have to use the same capstyle for all lines
        # even when they are dashed. It's a slight inconsistency, but looks fine IMO.
        self.artist_kws.setdefault("capstyle", mpl.rcParams["lines.solid_capstyle"])

    def _plot(self, split_gen, scales, orient):
        """Collect segments per axes and add them as one LineCollection each."""
        line_data = {}
        for keys, data, ax in split_gen(keep_na=not self._sort):

            if ax not in line_data:
                line_data[ax] = {
                    "segments": [],
                    "colors": [],
                    "linewidths": [],
                    "linestyles": [],
                }

            segments = self._setup_segments(data, orient)
            line_data[ax]["segments"].extend(segments)
            n = len(segments)

            # Every segment of this subset shares the subset's style values
            vals = resolve_properties(self, keys, scales)
            vals["color"] = resolve_color(self, keys, scales=scales)

            line_data[ax]["colors"].extend([vals["color"]] * n)
            line_data[ax]["linewidths"].extend([vals["linewidth"]] * n)
            line_data[ax]["linestyles"].extend([vals["linestyle"]] * n)

        for ax, ax_data in line_data.items():
            lines = mpl.collections.LineCollection(**ax_data, **self.artist_kws)
            # Handle datalim update manually
            # https://github.com/matplotlib/matplotlib/issues/23129
            ax.add_collection(lines, autolim=False)
            if ax_data["segments"]:
                xy = np.concatenate(ax_data["segments"])
                ax.update_datalim(xy)

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D styled for `value` of the given variables."""
        key = resolve_properties(self, {v: value for v in variables}, scales)

        # Line2D has no single "capstyle" keyword; apply the collection's
        # value to both the solid and the dashed cap style instead
        artist_kws = self.artist_kws.copy()
        capstyle = artist_kws.pop("capstyle")
        artist_kws["solid_capstyle"] = capstyle
        artist_kws["dash_capstyle"] = capstyle

        return mpl.lines.Line2D(
            [], [],
            color=key["color"],
            linewidth=key["linewidth"],
            linestyle=key["linestyle"],
            **artist_kws,
        )

    def _setup_segments(self, data, orient):
        """Return the subset as a one-element list holding its (x, y) array."""
        if self._sort:
            data = data.sort_values(orient, kind="mergesort")

        # Column stack to avoid block consolidation
        xy = np.column_stack([data["x"], data["y"]])

        return [xy]


@document_properties
@dataclass
class Lines(Paths):
    """
    A faster but less-flexible mark for drawing many lines.

    See also
    --------
    Line : A mark connecting data points with sorting along the orientation axis.

    Examples
    --------
    .. include:: ../docstrings/objects.Lines.rst

    """
    _sort: ClassVar[bool] = True


@document_properties
@dataclass
class Range(Paths):
    """
    An oriented line mark drawn between min/max values.

    Examples
    --------
    .. include:: ../docstrings/objects.Range.rst

    """
    def _setup_segments(self, data, orient):
        """Return one segment per distinct orient value, spanning min to max.

        When the data have neither a ``{val}min`` nor a ``{val}max`` column,
        both are computed by aggregating the value column within each orient
        value.
        """
        # TODO better checks on what variables we have
        # TODO what if only one exist?
        val = {"x": "y", "y": "x"}[orient]
        if not set(data.columns) & {f"{val}min", f"{val}max"}:
            agg = {f"{val}min": (val, "min"), f"{val}max": (val, "max")}
            data = data.groupby(orient).agg(**agg).reset_index()

        # Long form: the min and max become separate rows of the value column
        cols = [orient, f"{val}min", f"{val}max"]
        data = data[cols].melt(orient, value_name=val)[["x", "y"]]
        segments = [d.to_numpy() for _, d in data.groupby(orient)]
        return segments


@document_properties
@dataclass
class Dash(Paths):
    """
    A line mark drawn as an oriented segment for each datapoint.

    Examples
    --------
    .. include:: ../docstrings/objects.Dash.rst

    """
    width: MappableFloat = Mappable(.8, grouping=False)

    def _setup_segments(self, data, orient):
        """Return an array of two-point segments centered on each datapoint.

        Each segment extends ``width / 2`` to either side of the point along
        the orient axis.
        """
        ori = ["x", "y"].index(orient)
        xys = data[["x", "y"]].to_numpy().astype(float)
        # Duplicate each point as (start, end), then push the ends apart
        segments = np.stack([xys, xys], axis=1)
        segments[:, 0, ori] -= data["width"] / 2
        segments[:, 1, ori] += data["width"] / 2
        return segments


@document_properties
@dataclass
class Text(Mark):
    """
    A textual mark to annotate or represent data values.

    Examples
    --------
    .. include:: ../docstrings/objects.Text.rst

    """
    text: MappableString = Mappable("")
    color: MappableColor = Mappable("k")
    alpha: MappableFloat = Mappable(1)
    fontsize: MappableFloat = Mappable(rc="font.size")
    halign: MappableString = Mappable("center")
    valign: MappableString = Mappable("center_baseline")
    offset: MappableFloat = Mappable(4)

    def _plot(self, split_gen, scales, orient):
        """Add one Text artist per data row and update each axes' data limits."""
        ax_data = defaultdict(list)

        for keys, data, ax in split_gen():

            vals = resolve_properties(self, keys, scales)
            color = resolve_color(self, keys, "", scales)

            halign = vals["halign"]
            valign = vals["valign"]
            fontsize = vals["fontsize"]
            # Divided by 72 for use with the figure's dpi_scale_trans below
            offset = vals["offset"] / 72

            # Shift the text away from the anchor on the aligned side only;
            # centered alignments get no offset
            offset_trans = ScaledTranslation(
                {"right": -offset, "left": +offset}.get(halign, 0),
                {"top": -offset, "bottom": +offset, "baseline": +offset}.get(valign, 0),
                ax.figure.dpi_scale_trans,
            )

            for row in data.to_dict("records"):
                artist = mpl.text.Text(
                    x=row["x"],
                    y=row["y"],
                    # Fall back to the resolved value when no text is mapped
                    text=str(row.get("text", vals["text"])),
                    color=color,
                    fontsize=fontsize,
                    horizontalalignment=halign,
                    verticalalignment=valign,
                    transform=ax.transData + offset_trans,
                    **self.artist_kws,
                )
                ax.add_artist(artist)
                ax_data[ax].append([row["x"], row["y"]])

        # Update the data limits manually from the collected text anchors
        for ax, ax_vals in ax_data.items():
            ax.update_datalim(np.array(ax_vals))
- Some class have data-dependent preprocessing that should be cached and used multiple times (think defining histogram bins off all data and then counting observations within each bin multiple times per data subsets). These currently have unique names, but it would be good to have a common name. Not quite `fit`, but something similar. - Alternatively, the transform interface could take some information about grouping variables and do a groupby internally. - Some classes should define alternate transforms that might make the most sense with a different function. For example, KDE usually evaluates the distribution on a regular grid, but it would be useful for it to transform at the actual datapoints. Then again, this could be controlled by a parameter at the time of class instantiation. """ from numbers import Number from statistics import NormalDist import numpy as np import pandas as pd try: from scipy.stats import gaussian_kde _no_scipy = False except ImportError: from .external.kde import gaussian_kde _no_scipy = True from .algorithms import bootstrap from .utils import _check_argument class KDE: """Univariate and bivariate kernel density estimator.""" def __init__( self, *, bw_method=None, bw_adjust=1, gridsize=200, cut=3, clip=None, cumulative=False, ): """Initialize the estimator with its parameters. Parameters ---------- bw_method : string, scalar, or callable, optional Method for determining the smoothing bandwidth to use; passed to :class:`scipy.stats.gaussian_kde`. bw_adjust : number, optional Factor that multiplicatively scales the value chosen using ``bw_method``. Increasing will make the curve smoother. See Notes. gridsize : int, optional Number of points on each dimension of the evaluation grid. cut : number, optional Factor, multiplied by the smoothing bandwidth, that determines how far the evaluation grid extends past the extreme datapoints. When set to 0, truncate the curve at the data limits. 
clip : pair of numbers or None, or a pair of such pairs Do not evaluate the density outside of these limits. cumulative : bool, optional If True, estimate a cumulative distribution function. Requires scipy. """ if clip is None: clip = None, None self.bw_method = bw_method self.bw_adjust = bw_adjust self.gridsize = gridsize self.cut = cut self.clip = clip self.cumulative = cumulative if cumulative and _no_scipy: raise RuntimeError("Cumulative KDE evaluation requires scipy") self.support = None def _define_support_grid(self, x, bw, cut, clip, gridsize): """Create the grid of evaluation points depending for vector x.""" clip_lo = -np.inf if clip[0] is None else clip[0] clip_hi = +np.inf if clip[1] is None else clip[1] gridmin = max(x.min() - bw * cut, clip_lo) gridmax = min(x.max() + bw * cut, clip_hi) return np.linspace(gridmin, gridmax, gridsize) def _define_support_univariate(self, x, weights): """Create a 1D grid of evaluation points.""" kde = self._fit(x, weights) bw = np.sqrt(kde.covariance.squeeze()) grid = self._define_support_grid( x, bw, self.cut, self.clip, self.gridsize ) return grid def _define_support_bivariate(self, x1, x2, weights): """Create a 2D grid of evaluation points.""" clip = self.clip if clip[0] is None or np.isscalar(clip[0]): clip = (clip, clip) kde = self._fit([x1, x2], weights) bw = np.sqrt(np.diag(kde.covariance).squeeze()) grid1 = self._define_support_grid( x1, bw[0], self.cut, clip[0], self.gridsize ) grid2 = self._define_support_grid( x2, bw[1], self.cut, clip[1], self.gridsize ) return grid1, grid2 def define_support(self, x1, x2=None, weights=None, cache=True): """Create the evaluation grid for a given data set.""" if x2 is None: support = self._define_support_univariate(x1, weights) else: support = self._define_support_bivariate(x1, x2, weights) if cache: self.support = support return support def _fit(self, fit_data, weights=None): """Fit the scipy kde while adding bw_adjust logic and version check.""" fit_kws = {"bw_method": 
self.bw_method} if weights is not None: fit_kws["weights"] = weights kde = gaussian_kde(fit_data, **fit_kws) kde.set_bandwidth(kde.factor * self.bw_adjust) return kde def _eval_univariate(self, x, weights=None): """Fit and evaluate a univariate on univariate data.""" support = self.support if support is None: support = self.define_support(x, cache=False) kde = self._fit(x, weights) if self.cumulative: s_0 = support[0] density = np.array([ kde.integrate_box_1d(s_0, s_i) for s_i in support ]) else: density = kde(support) return density, support def _eval_bivariate(self, x1, x2, weights=None): """Fit and evaluate a univariate on bivariate data.""" support = self.support if support is None: support = self.define_support(x1, x2, cache=False) kde = self._fit([x1, x2], weights) if self.cumulative: grid1, grid2 = support density = np.zeros((grid1.size, grid2.size)) p0 = grid1.min(), grid2.min() for i, xi in enumerate(grid1): for j, xj in enumerate(grid2): density[i, j] = kde.integrate_box(p0, (xi, xj)) else: xx1, xx2 = np.meshgrid(*support) density = kde([xx1.ravel(), xx2.ravel()]).reshape(xx1.shape) return density, support def __call__(self, x1, x2=None, weights=None): """Fit and evaluate on univariate or bivariate data.""" if x2 is None: return self._eval_univariate(x1, weights) else: return self._eval_bivariate(x1, x2, weights) # Note: we no longer use this for univariate histograms in histplot, # preferring _stats.Hist. We'll deprecate this once we have a bivariate Stat class. class Histogram: """Univariate and bivariate histogram estimator.""" def __init__( self, stat="count", bins="auto", binwidth=None, binrange=None, discrete=False, cumulative=False, ): """Initialize the estimator with its parameters. Parameters ---------- stat : str Aggregate statistic to compute in each bin. 
- `count`: show the number of observations in each bin - `frequency`: show the number of observations divided by the bin width - `probability` or `proportion`: normalize such that bar heights sum to 1 - `percent`: normalize such that bar heights sum to 100 - `density`: normalize such that the total area of the histogram equals 1 bins : str, number, vector, or a pair of such values Generic bin parameter that can be the name of a reference rule, the number of bins, or the breaks of the bins. Passed to :func:`numpy.histogram_bin_edges`. binwidth : number or pair of numbers Width of each bin, overrides ``bins`` but can be used with ``binrange``. binrange : pair of numbers or a pair of pairs Lowest and highest value for bin edges; can be used either with ``bins`` or ``binwidth``. Defaults to data extremes. discrete : bool or pair of bools If True, set ``binwidth`` and ``binrange`` such that bin edges cover integer values in the dataset. cumulative : bool If True, return the cumulative statistic. """ stat_choices = [ "count", "frequency", "density", "probability", "proportion", "percent", ] _check_argument("stat", stat_choices, stat) self.stat = stat self.bins = bins self.binwidth = binwidth self.binrange = binrange self.discrete = discrete self.cumulative = cumulative self.bin_kws = None def _define_bin_edges(self, x, weights, bins, binwidth, binrange, discrete): """Inner function that takes bin parameters as arguments.""" if binrange is None: start, stop = x.min(), x.max() else: start, stop = binrange if discrete: bin_edges = np.arange(start - .5, stop + 1.5) elif binwidth is not None: step = binwidth bin_edges = np.arange(start, stop + step, step) # Handle roundoff error (maybe there is a less clumsy way?) 
if bin_edges.max() < stop or len(bin_edges) < 2: bin_edges = np.append(bin_edges, bin_edges.max() + step) else: bin_edges = np.histogram_bin_edges( x, bins, binrange, weights, ) return bin_edges def define_bin_params(self, x1, x2=None, weights=None, cache=True): """Given data, return numpy.histogram parameters to define bins.""" if x2 is None: bin_edges = self._define_bin_edges( x1, weights, self.bins, self.binwidth, self.binrange, self.discrete, ) if isinstance(self.bins, (str, Number)): n_bins = len(bin_edges) - 1 bin_range = bin_edges.min(), bin_edges.max() bin_kws = dict(bins=n_bins, range=bin_range) else: bin_kws = dict(bins=bin_edges) else: bin_edges = [] for i, x in enumerate([x1, x2]): # Resolve out whether bin parameters are shared # or specific to each variable bins = self.bins if not bins or isinstance(bins, (str, Number)): pass elif isinstance(bins[i], str): bins = bins[i] elif len(bins) == 2: bins = bins[i] binwidth = self.binwidth if binwidth is None: pass elif not isinstance(binwidth, Number): binwidth = binwidth[i] binrange = self.binrange if binrange is None: pass elif not isinstance(binrange[0], Number): binrange = binrange[i] discrete = self.discrete if not isinstance(discrete, bool): discrete = discrete[i] # Define the bins for this variable bin_edges.append(self._define_bin_edges( x, weights, bins, binwidth, binrange, discrete, )) bin_kws = dict(bins=tuple(bin_edges)) if cache: self.bin_kws = bin_kws return bin_kws def _eval_bivariate(self, x1, x2, weights): """Inner function for histogram of two variables.""" bin_kws = self.bin_kws if bin_kws is None: bin_kws = self.define_bin_params(x1, x2, cache=False) density = self.stat == "density" hist, *bin_edges = np.histogram2d( x1, x2, **bin_kws, weights=weights, density=density ) area = np.outer( np.diff(bin_edges[0]), np.diff(bin_edges[1]), ) if self.stat == "probability" or self.stat == "proportion": hist = hist.astype(float) / hist.sum() elif self.stat == "percent": hist = hist.astype(float) / 
class ECDF:
    """Univariate empirical cumulative distribution estimator."""

    def __init__(self, stat="proportion", complementary=False):
        """Store the estimator settings after validating ``stat``.

        Parameters
        ----------
        stat : {"proportion", "percent", "count"}
            Distribution statistic to compute.
        complementary : bool
            If True, use the complementary CDF (1 - CDF)

        """
        _check_argument("stat", ["count", "percent", "proportion"], stat)
        self.stat = stat
        self.complementary = complementary

    def _eval_bivariate(self, x1, x2, weights):
        """Two-variable ECDF; always raises because it is unsupported."""
        raise NotImplementedError("Bivariate ECDF is not implemented")

    def _eval_univariate(self, x, weights):
        """Return (statistic, sorted support) for a single variable."""
        order = x.argsort()
        support = x[order]
        heights = weights[order].cumsum()

        # Rescale the running total for the normalized statistics
        if self.stat == "proportion":
            heights = heights / heights.max()
        elif self.stat == "percent":
            heights = heights / heights.max() * 100

        # Prepend a point at -inf where the statistic is still zero
        support = np.r_[-np.inf, support]
        heights = np.r_[0, heights]

        if self.complementary:
            heights = heights.max() - heights

        return heights, support

    def __call__(self, x1, x2=None, weights=None):
        """Return proportion or count of observations below each sorted datapoint."""
        x1 = np.asarray(x1)
        # Unweighted data counts every observation once
        weights = np.ones_like(x1) if weights is None else np.asarray(weights)

        if x2 is not None:
            return self._eval_bivariate(x1, x2, weights)
        return self._eval_univariate(x1, weights)
` for more information. boot_kws Additional keywords are passed to bootstrap when error_method is "ci". """ self.estimator = estimator method, level = _validate_errorbar_arg(errorbar) self.error_method = method self.error_level = level self.boot_kws = boot_kws def __call__(self, data, var): """Aggregate over `var` column of `data` with estimate and error interval.""" vals = data[var] if callable(self.estimator): # You would think we could pass to vals.agg, and yet: # https://github.com/mwaskom/seaborn/issues/2943 estimate = self.estimator(vals) else: estimate = vals.agg(self.estimator) # Options that produce no error bars if self.error_method is None: err_min = err_max = np.nan elif len(data) <= 1: err_min = err_max = np.nan # Generic errorbars from user-supplied function elif callable(self.error_method): err_min, err_max = self.error_method(vals) # Parametric options elif self.error_method == "sd": half_interval = vals.std() * self.error_level err_min, err_max = estimate - half_interval, estimate + half_interval elif self.error_method == "se": half_interval = vals.sem() * self.error_level err_min, err_max = estimate - half_interval, estimate + half_interval # Nonparametric options elif self.error_method == "pi": err_min, err_max = _percentile_interval(vals, self.error_level) elif self.error_method == "ci": units = data.get("units", None) boots = bootstrap(vals, units=units, func=self.estimator, **self.boot_kws) err_min, err_max = _percentile_interval(boots, self.error_level) return pd.Series({var: estimate, f"{var}min": err_min, f"{var}max": err_max}) class WeightedAggregator: def __init__(self, estimator, errorbar=None, **boot_kws): """ Data aggregator that produces a weighted estimate and error bar interval. Parameters ---------- estimator : string Function (or method name) that maps a vector to a scalar. Currently supports only "mean". errorbar : string or (string, number) tuple Name of errorbar method or a tuple with a method name and a level parameter. 
class LetterValues:

    def __init__(self, k_depth, outlier_prop, trust_alpha):
        """
        Compute percentiles of a distribution using various tail stopping rules.

        Parameters
        ----------
        k_depth: "tukey", "proportion", "trustworthy", "full", or int
            Stopping rule for choosing tail percentiles to show:

            - tukey: Show a similar number of outliers as in a conventional boxplot.
            - proportion: Show approximately `outlier_prop` outliers.
            - trustworthy: Use `trust_alpha` level for most extreme tail percentile.
            - full: Extend the boxes to cover all of the data.

            An integer sets the number of levels directly.

        outlier_prop: float
            Parameter for `k_depth="proportion"` setting the expected outlier rate.
        trust_alpha: float
            Parameter for `k_depth="trustworthy"` setting the confidence threshold.

        Notes
        -----
        Based on the proposal in this paper:
        https://vita.had.co.nz/papers/letter-value-plot.pdf

        """
        k_options = ["tukey", "proportion", "trustworthy", "full"]
        if isinstance(k_depth, str):
            _check_argument("k_depth", k_options, k_depth)
        elif not isinstance(k_depth, int):
            raise TypeError(
                "The `k_depth` parameter must be either an integer or string "
                f"(one of {k_options}), not {k_depth!r}."
            )

        self.k_depth = k_depth
        self.outlier_prop = outlier_prop
        self.trust_alpha = trust_alpha

    def _compute_k(self, n):
        """Return the number of levels for `n` observations (at least 1)."""
        rule = self.k_depth

        def floor_log2(value):
            return int(np.log2(value))

        if rule == "full":
            # Extend boxes to 100% of the data
            k = floor_log2(n) + 1
        elif rule == "tukey":
            # This results with 5-8 points in each tail
            k = floor_log2(n) - 3
        elif rule == "proportion":
            k = floor_log2(n) - floor_log2(n * self.outlier_prop) + 1
        elif rule == "trustworthy":
            normal_quantile_func = np.vectorize(NormalDist().inv_cdf)
            point_conf = 2 * normal_quantile_func(1 - self.trust_alpha / 2) ** 2
            k = floor_log2(n / point_conf) + 1
        else:
            # The depth was given directly as a number
            k = int(rule)

        return k if k > 1 else 1

    def __call__(self, x):
        """Evaluate the letter values."""
        k = self._compute_k(len(x))

        # Exponents of 1/2 for the lower and upper tails, outermost first/last
        lower_exp = np.arange(k + 1, 1, -1)
        upper_exp = np.arange(2, k + 2)

        levels = k + 1 - np.concatenate([lower_exp, upper_exp[1:]])
        percentiles = 100 * np.concatenate([0.5 ** lower_exp, 1 - 0.5 ** upper_exp])
        if self.k_depth == "full":
            # The outermost boxes reach the data extremes
            percentiles[0] = 0
            percentiles[-1] = 100

        values = np.percentile(x, percentiles)
        outside = (x < values.min()) | (x > values.max())

        return {
            "k": k,
            "levels": levels,
            "percs": percentiles,
            "values": values,
            "fliers": np.asarray(x[outside]),
            "median": np.percentile(x, 50),
        }
def _validate_errorbar_arg(arg):
    """Normalize an `errorbar` argument into a (method, level) pair.

    None and callables pass through with no level; a bare method name gets
    its default level; anything else must unpack into (method, level).
    """
    DEFAULT_LEVELS = {
        "ci": 95,
        "pi": 95,
        "se": 1,
        "sd": 1,
    }
    usage = "`errorbar` must be a callable, string, or (string, number) tuple"

    if arg is None:
        return None, None
    if callable(arg):
        return arg, None

    if isinstance(arg, str):
        method, level = arg, DEFAULT_LEVELS.get(arg, None)
    else:
        try:
            method, level = arg
        except (ValueError, TypeError) as err:
            # Re-raise the same exception type with the usage hint
            raise err.__class__(usage) from err

    _check_argument("errorbar", list(DEFAULT_LEVELS), method)
    if not (level is None or isinstance(level, Number)):
        raise TypeError(usage)

    return method, level
def __call__(
    self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:
    """Aggregate the value variable within groups using `self.func`."""
    # The value axis is the coordinate that is not the orient axis
    var = {"x": "y", "y": "x"}.get(orient)

    aggregated = groupby.agg(data, {var: self.func})
    # Discard groups whose aggregate is missing, then renumber the rows
    defined = aggregated.dropna(subset=[var])
    return defined.reset_index(drop=True)
@dataclass
class Stat:
    """Base class for objects that apply statistical transformations."""

    # The class supports a partial-function application pattern. The object is
    # initialized with desired parameters and the result is a callable that
    # accepts and returns dataframes.

    # The statistical transformation logic should not add any state to the instance
    # beyond what is defined with the initialization parameters.

    # Subclasses can declare whether the orient dimension should be used in grouping
    # TODO consider whether this should be a parameter. Motivating example:
    # use the same KDE class violin plots and univariate density estimation.
    # In the former case, we would expect separate densities for each unique
    # value on the orient axis, but we would not in the latter case.
    group_by_orient: ClassVar[bool] = False

    def _check_param_one_of(self, param: str, options: Iterable[Any]) -> None:
        """Raise when parameter value is not one of a specified set.

        Parameters
        ----------
        param
            Name of the attribute on this object to validate.
        options
            Allowed values; any iterable, consumed at most once.

        Raises
        ------
        ValueError
            If the attribute value is not among `options`. The message lists
            every allowed value.

        """
        # Materialize once so that one-shot iterables survive both the
        # membership test and the message formatting below.
        choices = list(options)
        value = getattr(self, param)
        if value in choices:
            return

        reprs = [repr(x) for x in choices]
        if len(reprs) > 1:
            # Every option appears: "'a', 'b' or 'c'"
            option_str = ", ".join(reprs[:-1]) + f" or {reprs[-1]}"
        else:
            option_str = "".join(reprs)
        err = " ".join([
            f"The `{param}` parameter for `{self.__class__.__name__}` must be",
            f"one of {option_str}; not {value!r}.",
        ])
        raise ValueError(err)

    def _check_grouping_vars(
        self, param: str, data_vars: list[str], stacklevel: int = 2,
    ) -> None:
        """Warn if vars are named in parameter without being present in the data.

        Undefined names are reported in the order they appear in the
        parameter value, so the warning text is deterministic.
        """
        param_vars = getattr(self, param)
        known = set(data_vars)
        # dict.fromkeys de-duplicates while keeping first-seen order
        undefined = [v for v in dict.fromkeys(param_vars) if v not in known]
        if undefined:
            param = f"{self.__class__.__name__}.{param}"
            names = ", ".join(f"{x!r}" for x in undefined)
            msg = f"Undefined variable(s) passed for {param}: {names}."
            warnings.warn(msg, stacklevel=stacklevel)

    def __call__(
        self,
        data: DataFrame,
        groupby: GroupBy,
        orient: str,
        scales: dict[str, Scale],
    ) -> DataFrame:
        """Apply statistical transform to data subgroups and return combined result."""
        # The base class is an identity transform
        return data
binwidth : float Width of each bin; overrides `bins` but can be used with `binrange`. Note that if `binwidth` does not evenly divide the bin range, the actual bin width used will be only approximately equal to the parameter value. binrange : (min, max) Lowest and highest value for bin edges; can be used with either `bins` (when a number) or `binwidth`. Defaults to data extremes. common_norm : bool or list of variables When not `False`, the normalization is applied across groups. Use `True` to normalize across all groups, or pass variable name(s) that define normalization groups. common_bins : bool or list of variables When not `False`, the same bins are used for all groups. Use `True` to share bins across all groups, or pass variable name(s) to share within. cumulative : bool If True, cumulate the bin values. discrete : bool If True, set `binwidth` and `binrange` so that bins have unit width and are centered on integer values. Notes ----- The choice of bins for computing and plotting a histogram can exert substantial influence on the insights that one is able to draw from the visualization. If the bins are too large, they may erase important features. On the other hand, bins that are too small may be dominated by random variability, obscuring the shape of the true underlying distribution. The default bin size is determined using a reference rule that depends on the sample size and variance. This works well in many cases (i.e., with "well-behaved" data) but it fails in others. It is always a good idea to try different bin sizes to be sure that you are not missing something important. This function allows you to specify bins in several different ways, such as by setting the total number of bins to use, the width of each bin, or the specific locations where the bins should break. Examples -------- .. 
include:: ../docstrings/objects.Hist.rst """ stat: str = "count" bins: str | int | ArrayLike = "auto" binwidth: float | None = None binrange: tuple[float, float] | None = None common_norm: bool | list[str] = True common_bins: bool | list[str] = True cumulative: bool = False discrete: bool = False def __post_init__(self): stat_options = [ "count", "density", "percent", "probability", "proportion", "frequency" ] self._check_param_one_of("stat", stat_options) def _define_bin_edges(self, vals, weight, bins, binwidth, binrange, discrete): """Inner function that takes bin parameters as arguments.""" vals = vals.replace(-np.inf, np.nan).replace(np.inf, np.nan).dropna() if binrange is None: start, stop = vals.min(), vals.max() else: start, stop = binrange if discrete: bin_edges = np.arange(start - .5, stop + 1.5) else: if binwidth is not None: bins = int(round((stop - start) / binwidth)) bin_edges = np.histogram_bin_edges(vals, bins, binrange, weight) # TODO warning or cap on too many bins? return bin_edges def _define_bin_params(self, data, orient, scale_type): """Given data, return numpy.histogram parameters to define bins.""" vals = data[orient] weights = data.get("weight", None) # TODO We'll want this for ordinal / discrete scales too # (Do we need discrete as a parameter or just infer from scale?) 
discrete = self.discrete or scale_type == "nominal" bin_edges = self._define_bin_edges( vals, weights, self.bins, self.binwidth, self.binrange, discrete, ) if isinstance(self.bins, (str, int)): n_bins = len(bin_edges) - 1 bin_range = bin_edges.min(), bin_edges.max() bin_kws = dict(bins=n_bins, range=bin_range) else: bin_kws = dict(bins=bin_edges) return bin_kws def _get_bins_and_eval(self, data, orient, groupby, scale_type): bin_kws = self._define_bin_params(data, orient, scale_type) return groupby.apply(data, self._eval, orient, bin_kws) def _eval(self, data, orient, bin_kws): vals = data[orient] weights = data.get("weight", None) density = self.stat == "density" hist, edges = np.histogram(vals, **bin_kws, weights=weights, density=density) width = np.diff(edges) center = edges[:-1] + width / 2 return pd.DataFrame({orient: center, "count": hist, "space": width}) def _normalize(self, data): hist = data["count"] if self.stat == "probability" or self.stat == "proportion": hist = hist.astype(float) / hist.sum() elif self.stat == "percent": hist = hist.astype(float) / hist.sum() * 100 elif self.stat == "frequency": hist = hist.astype(float) / data["space"] if self.cumulative: if self.stat in ["density", "frequency"]: hist = (hist * data["space"]).cumsum() else: hist = hist.cumsum() return data.assign(**{self.stat: hist}) def __call__( self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale], ) -> DataFrame: scale_type = scales[orient].__class__.__name__.lower() grouping_vars = [str(v) for v in data if v in groupby.order] if not grouping_vars or self.common_bins is True: bin_kws = self._define_bin_params(data, orient, scale_type) data = groupby.apply(data, self._eval, orient, bin_kws) else: if self.common_bins is False: bin_groupby = GroupBy(grouping_vars) else: bin_groupby = GroupBy(self.common_bins) self._check_grouping_vars("common_bins", grouping_vars) data = bin_groupby.apply( data, self._get_bins_and_eval, orient, groupby, scale_type, ) if not 
grouping_vars or self.common_norm is True: data = self._normalize(data) else: if self.common_norm is False: norm_groupby = GroupBy(grouping_vars) else: norm_groupby = GroupBy(self.common_norm) self._check_grouping_vars("common_norm", grouping_vars) data = norm_groupby.apply(data, self._normalize) other = {"x": "y", "y": "x"}[orient] return data.assign(**{other: data[self.stat]}) ================================================ FILE: seaborn/_stats/density.py ================================================ from __future__ import annotations from dataclasses import dataclass from typing import Any, Callable import numpy as np from numpy import ndarray import pandas as pd from pandas import DataFrame try: from scipy.stats import gaussian_kde _no_scipy = False except ImportError: from seaborn.external.kde import gaussian_kde _no_scipy = True from seaborn._core.groupby import GroupBy from seaborn._core.scales import Scale from seaborn._stats.base import Stat @dataclass class KDE(Stat): """ Compute a univariate kernel density estimate. Parameters ---------- bw_adjust : float Factor that multiplicatively scales the value chosen using `bw_method`. Increasing will make the curve smoother. See Notes. bw_method : string, scalar, or callable Method for determining the smoothing bandwidth to use. Passed directly to :class:`scipy.stats.gaussian_kde`; see there for options. common_norm : bool or list of variables If `True`, normalize so that the areas of all curves sums to 1. If `False`, normalize each curve independently. If a list, defines variable(s) to group by and normalize within. common_grid : bool or list of variables If `True`, all curves will share the same evaluation grid. If `False`, each evaluation grid is independent. If a list, defines variable(s) to group by and share a grid within. gridsize : int or None Number of points in the evaluation grid. If None, the density is evaluated at the original datapoints. 
cut : float Factor, multiplied by the kernel bandwidth, that determines how far the evaluation grid extends past the extreme datapoints. When set to 0, the curve is truncated at the data limits. cumulative : bool If True, estimate a cumulative distribution function. Requires scipy. Notes ----- The *bandwidth*, or standard deviation of the smoothing kernel, is an important parameter. Much like histogram bin width, using the wrong bandwidth can produce a distorted representation. Over-smoothing can erase true features, while under-smoothing can create false ones. The default uses a rule-of-thumb that works best for distributions that are roughly bell-shaped. It is a good idea to check the default by varying `bw_adjust`. Because the smoothing is performed with a Gaussian kernel, the estimated density curve can extend to values that may not make sense. For example, the curve may be drawn over negative values for data that are naturally positive. The `cut` parameter can be used to control the evaluation range, but datasets that have many observations close to a natural boundary may be better served by a different method. Similar distortions may arise when a dataset is naturally discrete or "spiky" (containing many repeated observations of the same value). KDEs will always produce a smooth curve, which could be misleading. The units on the density axis are a common source of confusion. While kernel density estimation produces a probability distribution, the height of the curve at each point gives a density, not a probability. A probability can be obtained only by integrating the density across a range. The curve is normalized so that the integral over all possible values is 1, meaning that the scale of the density axis depends on the data values. If scipy is installed, its cython-accelerated implementation will be used. Examples -------- .. 
include:: ../docstrings/objects.KDE.rst """ bw_adjust: float = 1 bw_method: str | float | Callable[[gaussian_kde], float] = "scott" common_norm: bool | list[str] = True common_grid: bool | list[str] = True gridsize: int | None = 200 cut: float = 3 cumulative: bool = False def __post_init__(self): if self.cumulative and _no_scipy: raise RuntimeError("Cumulative KDE evaluation requires scipy") def _check_var_list_or_boolean(self, param: str, grouping_vars: Any) -> None: """Do input checks on grouping parameters.""" value = getattr(self, param) if not ( isinstance(value, bool) or (isinstance(value, list) and all(isinstance(v, str) for v in value)) ): param_name = f"{self.__class__.__name__}.{param}" raise TypeError(f"{param_name} must be a boolean or list of strings.") self._check_grouping_vars(param, grouping_vars, stacklevel=3) def _fit(self, data: DataFrame, orient: str) -> gaussian_kde: """Fit and return a KDE object.""" # TODO need to handle singular data fit_kws: dict[str, Any] = {"bw_method": self.bw_method} if "weight" in data: fit_kws["weights"] = data["weight"] kde = gaussian_kde(data[orient], **fit_kws) kde.set_bandwidth(kde.factor * self.bw_adjust) return kde def _get_support(self, data: DataFrame, orient: str) -> ndarray: """Define the grid that the KDE will be evaluated on.""" if self.gridsize is None: return data[orient].to_numpy() kde = self._fit(data, orient) bw = np.sqrt(kde.covariance.squeeze()) gridmin = data[orient].min() - bw * self.cut gridmax = data[orient].max() + bw * self.cut return np.linspace(gridmin, gridmax, self.gridsize) def _fit_and_evaluate( self, data: DataFrame, orient: str, support: ndarray ) -> DataFrame: """Transform single group by fitting a KDE and evaluating on a support grid.""" empty = pd.DataFrame(columns=[orient, "weight", "density"], dtype=float) if len(data) < 2: return empty try: kde = self._fit(data, orient) except np.linalg.LinAlgError: return empty if self.cumulative: s_0 = support[0] density = 
np.array([kde.integrate_box_1d(s_0, s_i) for s_i in support]) else: density = kde(support) weight = data["weight"].sum() return pd.DataFrame({orient: support, "weight": weight, "density": density}) def _transform( self, data: DataFrame, orient: str, grouping_vars: list[str] ) -> DataFrame: """Transform multiple groups by fitting KDEs and evaluating.""" empty = pd.DataFrame(columns=[*data.columns, "density"], dtype=float) if len(data) < 2: return empty try: support = self._get_support(data, orient) except np.linalg.LinAlgError: return empty grouping_vars = [x for x in grouping_vars if data[x].nunique() > 1] if not grouping_vars: return self._fit_and_evaluate(data, orient, support) groupby = GroupBy(grouping_vars) return groupby.apply(data, self._fit_and_evaluate, orient, support) def __call__( self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale], ) -> DataFrame: if "weight" not in data: data = data.assign(weight=1) data = data.dropna(subset=[orient, "weight"]) # Transform each group separately grouping_vars = [str(v) for v in data if v in groupby.order] if not grouping_vars or self.common_grid is True: res = self._transform(data, orient, grouping_vars) else: if self.common_grid is False: grid_vars = grouping_vars else: self._check_var_list_or_boolean("common_grid", grouping_vars) grid_vars = [v for v in self.common_grid if v in grouping_vars] res = ( GroupBy(grid_vars) .apply(data, self._transform, orient, grouping_vars) ) # Normalize, potentially within groups if not grouping_vars or self.common_norm is True: res = res.assign(group_weight=data["weight"].sum()) else: if self.common_norm is False: norm_vars = grouping_vars else: self._check_var_list_or_boolean("common_norm", grouping_vars) norm_vars = [v for v in self.common_norm if v in grouping_vars] res = res.join( data.groupby(norm_vars)["weight"].sum().rename("group_weight"), on=norm_vars, ) res["density"] *= res.eval("weight / group_weight") value = {"x": "y", "y": "x"}[orient] 
res[value] = res["density"] return res.drop(["weight", "group_weight"], axis=1) ================================================ FILE: seaborn/_stats/order.py ================================================ from __future__ import annotations from dataclasses import dataclass from typing import ClassVar, cast try: from typing import Literal except ImportError: from typing_extensions import Literal # type: ignore import numpy as np from pandas import DataFrame from seaborn._core.scales import Scale from seaborn._core.groupby import GroupBy from seaborn._stats.base import Stat from seaborn.utils import _version_predates # From https://github.com/numpy/numpy/blob/main/numpy/lib/function_base.pyi _MethodKind = Literal[ "inverted_cdf", "averaged_inverted_cdf", "closest_observation", "interpolated_inverted_cdf", "hazen", "weibull", "linear", "median_unbiased", "normal_unbiased", "lower", "higher", "midpoint", "nearest", ] @dataclass class Perc(Stat): """ Replace observations with percentile values. Parameters ---------- k : list of numbers or int If a list of numbers, this gives the percentiles (in [0, 100]) to compute. If an integer, compute `k` evenly-spaced percentiles between 0 and 100. For example, `k=5` computes the 0, 25, 50, 75, and 100th percentiles. method : str Method for interpolating percentiles between observed datapoints. See :func:`numpy.percentile` for valid options and more information. Examples -------- .. 
@dataclass
class PolyFit(Stat):
    """
    Fit a polynomial of the given order and resample data onto predicted curve.
    """
    # This is a provisional class that is useful for building out functionality.
    # It may or may not change substantially in form or disappear as we think
    # through the organization of the stats subpackage.

    order: int = 2
    gridsize: int = 100

    def _fit_predict(self, data):
        """Fit one group and return the curve sampled on an even grid.

        The result is empty when `x` has too few distinct values for a
        polynomial of degree `self.order`.
        """
        x, y = data["x"], data["y"]

        if x.nunique() > self.order:
            coefs = np.polyfit(x, y, self.order)
            grid = np.linspace(x.min(), x.max(), self.gridsize)
            curve = np.polyval(coefs, grid)
        else:
            # TODO warn?
            grid = curve = []

        return pd.DataFrame(dict(x=grid, y=curve))

    # TODO we should have a way of identifying the method that will be applied
    # and then only define __call__ on a base-class of stats with this pattern

    def __call__(self, data, groupby, orient, scales):
        """Fit each group after discarding rows with a missing x or y."""
        complete = data.dropna(subset=["x", "y"])
        return groupby.apply(complete, self._fit_predict)
================================================ FILE: seaborn/_testing.py ================================================ import numpy as np import matplotlib as mpl from matplotlib.colors import to_rgb, to_rgba from numpy.testing import assert_array_equal USE_PROPS = [ "alpha", "edgecolor", "facecolor", "fill", "hatch", "height", "linestyle", "linewidth", "paths", "xy", "xydata", "sizes", "zorder", ] def assert_artists_equal(list1, list2): assert len(list1) == len(list2) for a1, a2 in zip(list1, list2): assert a1.__class__ == a2.__class__ prop1 = a1.properties() prop2 = a2.properties() for key in USE_PROPS: if key not in prop1: continue v1 = prop1[key] v2 = prop2[key] if key == "paths": for p1, p2 in zip(v1, v2): assert_array_equal(p1.vertices, p2.vertices) assert_array_equal(p1.codes, p2.codes) elif key == "color": v1 = mpl.colors.to_rgba(v1) v2 = mpl.colors.to_rgba(v2) assert v1 == v2 elif isinstance(v1, np.ndarray): assert_array_equal(v1, v2) else: assert v1 == v2 def assert_legends_equal(leg1, leg2): assert leg1.get_title().get_text() == leg2.get_title().get_text() for t1, t2 in zip(leg1.get_texts(), leg2.get_texts()): assert t1.get_text() == t2.get_text() assert_artists_equal( leg1.get_patches(), leg2.get_patches(), ) assert_artists_equal( leg1.get_lines(), leg2.get_lines(), ) def assert_plots_equal(ax1, ax2, labels=True): assert_artists_equal(ax1.patches, ax2.patches) assert_artists_equal(ax1.lines, ax2.lines) assert_artists_equal(ax1.collections, ax2.collections) if labels: assert ax1.get_xlabel() == ax2.get_xlabel() assert ax1.get_ylabel() == ax2.get_ylabel() def assert_colors_equal(a, b, check_alpha=True): def handle_array(x): if isinstance(x, np.ndarray): if x.ndim > 1: x = np.unique(x, axis=0).squeeze() if x.ndim > 1: raise ValueError("Color arrays must be 1 dimensional") return x a = handle_array(a) b = handle_array(b) f = to_rgba if check_alpha else to_rgb assert f(a) == f(b) ================================================ FILE: 
def bootstrap(*args, **kwargs):
    """Resample one or more arrays with replacement and store aggregate values.

    Positional arguments are a sequence of arrays to bootstrap along the first
    axis and pass to a summary function.

    Keyword arguments:
        n_boot : int, default=10000
            Number of iterations
        axis : int, default=None
            Will pass axis to ``func`` as a keyword argument.
        units : array, default=None
            Array of sampling unit IDs. When used the bootstrap resamples units
            and then observations within units instead of individual
            datapoints.
        func : string or callable, default="mean"
            Function to call on the args that are passed in. If string, uses as
            name of function in the numpy namespace. If nans are present in the
            data, will try to use nan-aware version of named function.
        seed : Generator | SeedSequence | RandomState | int | None
            Seed for the random number generator; useful if you want
            reproducible resamples.

    Returns
    -------
    boot_dist: array
        array of bootstrapped statistic values

    """
    # Every input has to share its length along the first axis
    if len({len(a) for a in args}) > 1:
        raise ValueError("All input arrays must have the same length")
    n_obs = len(args[0])

    # Read the options, falling back to the documented defaults
    n_boot = kwargs.get("n_boot", 10000)
    func = kwargs.get("func", "mean")
    axis = kwargs.get("axis", None)
    units = kwargs.get("units", None)
    random_seed = kwargs.get("random_seed", None)
    if random_seed is not None:
        warnings.warn("`random_seed` has been renamed to `seed` and will be removed")
    seed = kwargs.get("seed", random_seed)
    func_kwargs = {} if axis is None else {"axis": axis}

    # A RandomState is used directly; anything else seeds a new Generator
    rng = seed if isinstance(seed, np.random.RandomState) else np.random.default_rng(seed)

    arrays = [np.asarray(a) for a in args]
    if units is not None:
        units = np.asarray(units)

    if not isinstance(func, str):
        stat = func
    else:
        # Named functions are looked up in the numpy namespace
        stat = getattr(np, func)
        # Swap in the nan-aware variant when the data contain missing values
        has_nans = np.isnan(np.sum(np.column_stack(arrays)))
        if has_nans and not func.startswith("nan"):
            nan_stat = getattr(np, f"nan{func}", None)
            if nan_stat is not None:
                stat = nan_stat
            else:
                msg = f"Data contain nans but no nan-aware version of `{func}` found"
                warnings.warn(msg, UserWarning)

    # Prefer `integers`; fall back to `randint` when the rng does not have it
    draw = rng.integers if hasattr(rng, "integers") else rng.randint

    if units is not None:
        return _structured_bootstrap(arrays, n_boot, units, stat, func_kwargs, draw)

    estimates = []
    for _ in range(int(n_boot)):
        idx = draw(0, n_obs, n_obs, dtype=np.intp)  # intp is indexing dtype
        resampled = [a.take(idx, axis=0) for a in arrays]
        estimates.append(stat(*resampled, **func_kwargs))
    return np.array(estimates)


def _structured_bootstrap(args, n_boot, units, func, func_kwargs, integers):
    """Resample units instead of datapoints."""
    levels = np.unique(units)
    n_levels = len(levels)

    # For every input array, one sub-array per sampling unit
    grouped = [[arr[units == level] for level in levels] for arr in args]

    estimates = []
    for _ in range(int(n_boot)):
        # First pick which units take part in this iteration ...
        unit_idx = integers(0, n_levels, n_levels, dtype=np.intp)
        drawn = [[groups[j] for j in unit_idx] for groups in grouped]
        # ... then resample the observations inside each chosen unit
        within_idx = [
            integers(0, size, size, dtype=np.intp) for size in map(len, drawn[0])
        ]
        taken = [
            [chunk.take(idx, axis=0) for chunk, idx in zip(chunks, within_idx)]
            for chunks in drawn
        ]
        resampled = [np.concatenate(chunks) for chunks in taken]
        estimates.append(func(*resampled, **func_kwargs))
    return np.array(estimates)
return self._figure @property def figure(self): """Access the :class:`matplotlib.figure.Figure` object underlying the grid.""" return self._figure def apply(self, func, *args, **kwargs): """ Pass the grid to a user-supplied function and return self. The `func` must accept an object of this type for its first positional argument. Additional arguments are passed through. The return value of `func` is ignored; this method returns self. See the `pipe` method if you want the return value. Added in v0.12.0. """ func(self, *args, **kwargs) return self def pipe(self, func, *args, **kwargs): """ Pass the grid to a user-supplied function and return its value. The `func` must accept an object of this type for its first positional argument. Additional arguments are passed through. The return value of `func` becomes the return value of this method. See the `apply` method if you want to return self instead. Added in v0.12.0. """ return func(self, *args, **kwargs) def savefig(self, *args, **kwargs): """ Save an image of the plot. This wraps :meth:`matplotlib.figure.Figure.savefig`, using bbox_inches="tight" by default. Parameters are passed through to the matplotlib function. """ kwargs = kwargs.copy() kwargs.setdefault("bbox_inches", "tight") self.figure.savefig(*args, **kwargs) class Grid(_BaseGrid): """A grid that can have multiple subplots and an external legend.""" _margin_titles = False _legend_out = True def __init__(self): self._tight_layout_rect = [0, 0, 1, 1] self._tight_layout_pad = None # This attribute is set externally and is a hack to handle newer functions that # don't add proxy artists onto the Axes. We need an overall cleaner approach. 
def add_legend(self, legend_data=None, title=None, label_order=None,
               adjust_subtitles=False, **kwargs):
    """Draw a legend, maybe placing it outside axes and resizing the figure.

    When ``self._legend_out`` is true, the legend is attached to the figure,
    the figure is widened by the legend's width, and the subplots are
    squeezed to the left so the legend does not overlap them. Otherwise the
    legend is drawn inside the first axes of the grid.

    Parameters
    ----------
    legend_data : dict
        Dictionary mapping label names (or two-element tuples where the
        second element is a label name) to matplotlib artist handles. The
        default reads from ``self._legend_data``.
    title : string
        Title for the legend. The default reads from ``self._hue_var``.
    label_order : list of labels
        The order that the legend entries should appear in. The default
        reads from ``self.hue_names``, or from the keys of ``legend_data``
        when ``self.hue_names`` is None.
    adjust_subtitles : bool
        If True, modify entries with invisible artists to left-align
        the labels and set the font size to that of a title.
    kwargs : key, value pairings
        Other keyword arguments are passed to the underlying legend methods
        on the Figure or Axes object.

    Returns
    -------
    self : Grid instance
        Returns self for easy chaining.

    """
    # Find the data for the legend
    if legend_data is None:
        legend_data = self._legend_data
    if label_order is None:
        if self.hue_names is None:
            label_order = list(legend_data.keys())
        else:
            label_order = list(map(utils.to_utf8, self.hue_names))

    # Labels without a known handle get an invisible placeholder patch
    blank_handle = mpl.patches.Patch(alpha=0, linewidth=0)
    handles = [legend_data.get(lab, blank_handle) for lab in label_order]
    title = self._hue_var if title is None else title
    title_size = mpl.rcParams["legend.title_fontsize"]

    # Unpack nested labels from a hierarchical legend
    labels = []
    for entry in label_order:
        if isinstance(entry, tuple):
            # Only the second element of a tuple entry is shown as the label
            _, label = entry
        else:
            label = entry
        labels.append(label)

    # Set default legend kwargs
    kwargs.setdefault("scatterpoints", 1)

    if self._legend_out:

        kwargs.setdefault("frameon", False)
        kwargs.setdefault("loc", "center right")

        # Draw a full-figure legend outside the grid
        figlegend = self._figure.legend(handles, labels, **kwargs)

        self._legend = figlegend
        figlegend.set_title(title, prop={"size": title_size})

        if adjust_subtitles:
            adjust_legend_subtitles(figlegend)

        # Draw the plot to set the bounding boxes correctly
        _draw_figure(self._figure)

        # Calculate and set the new width of the figure so the legend fits
        # (the window extent is divided by the figure dpi to get inches)
        legend_width = figlegend.get_window_extent().width / self._figure.dpi
        fig_width, fig_height = self._figure.get_size_inches()
        self._figure.set_size_inches(fig_width + legend_width, fig_height)

        # Draw the plot again to get the new transformations
        _draw_figure(self._figure)

        # Now calculate how much space we need on the right side,
        # as a fraction of the widened figure plus a small margin
        legend_width = figlegend.get_window_extent().width / self._figure.dpi
        space_needed = legend_width / (fig_width + legend_width)
        margin = .04 if self._margin_titles else .01
        self._space_needed = margin + space_needed
        right = 1 - self._space_needed

        # Place the subplot axes to give space for the legend, and store the
        # same right edge in the rect that `tight_layout` uses by default
        self._figure.subplots_adjust(right=right)
        self._tight_layout_rect[2] = right

    else:
        # Draw a legend in the first axis
        ax = self.axes.flat[0]
        kwargs.setdefault("loc", "best")

        leg = ax.legend(handles, labels, **kwargs)
        leg.set_title(title, prop={"size": title_size})
        self._legend = leg

        if adjust_subtitles:
            adjust_legend_subtitles(leg)

    return self
kwargs : keyword arguments Additional keyword arguments to pass to :meth:`matplotlib.axes.Axes.tick_params`. Returns ------- self : Grid instance Returns self for easy chaining. """ for ax in self.figure.axes: ax.tick_params(axis=axis, **kwargs) return self _facet_docs = dict( data=dedent("""\ data : DataFrame Tidy ("long-form") dataframe where each column is a variable and each row is an observation.\ """), rowcol=dedent("""\ row, col : vectors or keys in ``data`` Variables that define subsets to plot on different facets.\ """), rowcol_order=dedent("""\ {row,col}_order : vector of strings Specify the order in which levels of the ``row`` and/or ``col`` variables appear in the grid of subplots.\ """), col_wrap=dedent("""\ col_wrap : int "Wrap" the column variable at this width, so that the column facets span multiple rows. Incompatible with a ``row`` facet.\ """), share_xy=dedent("""\ share{x,y} : bool, 'col', or 'row' optional If true, the facets will share y axes across columns and/or x axes across rows.\ """), height=dedent("""\ height : scalar Height (in inches) of each facet. See also: ``aspect``.\ """), aspect=dedent("""\ aspect : scalar Aspect ratio of each facet, so that ``aspect * height`` gives the width of each facet in inches.\ """), palette=dedent("""\ palette : palette name, list, or dict Colors to use for the different levels of the ``hue`` variable. Should be something that can be interpreted by :func:`color_palette`, or a dictionary mapping hue levels to matplotlib colors.\ """), legend_out=dedent("""\ legend_out : bool If ``True``, the figure size will be extended, and the legend will be drawn outside the plot on the center right.\ """), margin_titles=dedent("""\ margin_titles : bool If ``True``, the titles for the row variable are drawn to the right of the last column. This option is experimental and may not work in all cases.\ """), facet_kws=dedent("""\ facet_kws : dict Additional parameters passed to :class:`FacetGrid`. 
"""), ) class FacetGrid(Grid): """Multi-plot grid for plotting conditional relationships.""" def __init__( self, data, *, row=None, col=None, hue=None, col_wrap=None, sharex=True, sharey=True, height=3, aspect=1, palette=None, row_order=None, col_order=None, hue_order=None, hue_kws=None, dropna=False, legend_out=True, despine=True, margin_titles=False, xlim=None, ylim=None, subplot_kws=None, gridspec_kws=None, ): super().__init__() data = handle_data_source(data) # Determine the hue facet layer information hue_var = hue if hue is None: hue_names = None else: hue_names = categorical_order(data[hue], hue_order) colors = self._get_palette(data, hue, hue_order, palette) # Set up the lists of names for the row and column facet variables if row is None: row_names = [] else: row_names = categorical_order(data[row], row_order) if col is None: col_names = [] else: col_names = categorical_order(data[col], col_order) # Additional dict of kwarg -> list of values for mapping the hue var hue_kws = hue_kws if hue_kws is not None else {} # Make a boolean mask that is True anywhere there is an NA # value in one of the faceting variables, but only if dropna is True none_na = np.zeros(len(data), bool) if dropna: row_na = none_na if row is None else data[row].isnull() col_na = none_na if col is None else data[col].isnull() hue_na = none_na if hue is None else data[hue].isnull() not_na = ~(row_na | col_na | hue_na) else: not_na = ~none_na # Compute the grid shape ncol = 1 if col is None else len(col_names) nrow = 1 if row is None else len(row_names) self._n_facets = ncol * nrow self._col_wrap = col_wrap if col_wrap is not None: if row is not None: err = "Cannot use `row` and `col_wrap` together." 
raise ValueError(err) ncol = col_wrap nrow = int(np.ceil(len(col_names) / col_wrap)) self._ncol = ncol self._nrow = nrow # Calculate the base figure size # This can get stretched later by a legend # TODO this doesn't account for axis labels figsize = (ncol * height * aspect, nrow * height) # Validate some inputs if col_wrap is not None: margin_titles = False # Build the subplot keyword dictionary subplot_kws = {} if subplot_kws is None else subplot_kws.copy() gridspec_kws = {} if gridspec_kws is None else gridspec_kws.copy() if xlim is not None: subplot_kws["xlim"] = xlim if ylim is not None: subplot_kws["ylim"] = ylim # --- Initialize the subplot grid with _disable_autolayout(): fig = plt.figure(figsize=figsize) if col_wrap is None: kwargs = dict(squeeze=False, sharex=sharex, sharey=sharey, subplot_kw=subplot_kws, gridspec_kw=gridspec_kws) axes = fig.subplots(nrow, ncol, **kwargs) if col is None and row is None: axes_dict = {} elif col is None: axes_dict = dict(zip(row_names, axes.flat)) elif row is None: axes_dict = dict(zip(col_names, axes.flat)) else: facet_product = product(row_names, col_names) axes_dict = dict(zip(facet_product, axes.flat)) else: # If wrapping the col variable we need to make the grid ourselves if gridspec_kws: warnings.warn("`gridspec_kws` ignored when using `col_wrap`") n_axes = len(col_names) axes = np.empty(n_axes, object) axes[0] = fig.add_subplot(nrow, ncol, 1, **subplot_kws) if sharex: subplot_kws["sharex"] = axes[0] if sharey: subplot_kws["sharey"] = axes[0] for i in range(1, n_axes): axes[i] = fig.add_subplot(nrow, ncol, i + 1, **subplot_kws) axes_dict = dict(zip(col_names, axes)) # --- Set up the class attributes # Attributes that are part of the public API but accessed through # a property so that Sphinx adds them to the auto class doc self._figure = fig self._axes = axes self._axes_dict = axes_dict self._legend = None # Public attributes that aren't explicitly documented # (It's not obvious that having them be public was a good 
idea) self.data = data self.row_names = row_names self.col_names = col_names self.hue_names = hue_names self.hue_kws = hue_kws # Next the private variables self._nrow = nrow self._row_var = row self._ncol = ncol self._col_var = col self._margin_titles = margin_titles self._margin_titles_texts = [] self._col_wrap = col_wrap self._hue_var = hue_var self._colors = colors self._legend_out = legend_out self._legend_data = {} self._x_var = None self._y_var = None self._sharex = sharex self._sharey = sharey self._dropna = dropna self._not_na = not_na # --- Make the axes look good self.set_titles() self.tight_layout() if despine: self.despine() if sharex in [True, 'col']: for ax in self._not_bottom_axes: for label in ax.get_xticklabels(): label.set_visible(False) ax.xaxis.offsetText.set_visible(False) ax.xaxis.label.set_visible(False) if sharey in [True, 'row']: for ax in self._not_left_axes: for label in ax.get_yticklabels(): label.set_visible(False) ax.yaxis.offsetText.set_visible(False) ax.yaxis.label.set_visible(False) __init__.__doc__ = dedent("""\ Initialize the matplotlib figure and FacetGrid object. This class maps a dataset onto multiple axes arrayed in a grid of rows and columns that correspond to *levels* of variables in the dataset. The plots it produces are often called "lattice", "trellis", or "small-multiple" graphics. It can also represent levels of a third variable with the ``hue`` parameter, which plots different subsets of data in different colors. This uses color to resolve elements on a third dimension, but only draws subsets on top of each other and will not tailor the ``hue`` parameter for the specific visualization the way that axes-level functions that accept ``hue`` will. The basic workflow is to initialize the :class:`FacetGrid` object with the dataset and the variables that are used to structure the grid. Then one or more plotting functions can be applied to each subset by calling :meth:`FacetGrid.map` or :meth:`FacetGrid.map_dataframe`. 
Finally, the plot can be tweaked with other methods to do things like change the axis labels, use different ticks, or add a legend. See the detailed code examples below for more information. .. warning:: When using seaborn functions that infer semantic mappings from a dataset, care must be taken to synchronize those mappings across facets (e.g., by defining the ``hue`` mapping with a palette dict or setting the data type of the variables to ``category``). In most cases, it will be better to use a figure-level function (e.g. :func:`relplot` or :func:`catplot`) than to use :class:`FacetGrid` directly. See the :ref:`tutorial ` for more information. Parameters ---------- {data} row, col, hue : strings Variables that define subsets of the data, which will be drawn on separate facets in the grid. See the ``{{var}}_order`` parameters to control the order of levels of this variable. {col_wrap} {share_xy} {height} {aspect} {palette} {{row,col,hue}}_order : lists Order for the levels of the faceting variables. By default, this will be the order that the levels appear in ``data`` or, if the variables are pandas categoricals, the category order. hue_kws : dictionary of param -> list of values mapping Other keyword arguments to insert into the plotting call to let other plot attributes vary across levels of the hue variable (e.g. the markers in a scatterplot). {legend_out} despine : boolean Remove the top and right spines from the plots. {margin_titles} {{x, y}}lim: tuples Limits for each of the axes on each facet (only relevant when share{{x, y}} is True). subplot_kws : dict Dictionary of keyword arguments passed to matplotlib subplot(s) methods. gridspec_kws : dict Dictionary of keyword arguments passed to :class:`matplotlib.gridspec.GridSpec` (via :meth:`matplotlib.figure.Figure.subplots`). Ignored if ``col_wrap`` is not ``None``. 
See Also -------- PairGrid : Subplot grid for plotting pairwise relationships relplot : Combine a relational plot and a :class:`FacetGrid` displot : Combine a distribution plot and a :class:`FacetGrid` catplot : Combine a categorical plot and a :class:`FacetGrid` lmplot : Combine a regression plot and a :class:`FacetGrid` Examples -------- .. note:: These examples use seaborn functions to demonstrate some of the advanced features of the class, but in most cases you will want to use figue-level functions (e.g. :func:`displot`, :func:`relplot`) to make the plots shown here. .. include:: ../docstrings/FacetGrid.rst """).format(**_facet_docs) def facet_data(self): """Generator for name indices and data subsets for each facet. Yields ------ (i, j, k), data_ijk : tuple of ints, DataFrame The ints provide an index into the {row, col, hue}_names attribute, and the dataframe contains a subset of the full data corresponding to each facet. The generator yields subsets that correspond with the self.axes.flat iterator, or self.axes[i, j] when `col_wrap` is None. """ data = self.data # Construct masks for the row variable if self.row_names: row_masks = [data[self._row_var] == n for n in self.row_names] else: row_masks = [np.repeat(True, len(self.data))] # Construct masks for the column variable if self.col_names: col_masks = [data[self._col_var] == n for n in self.col_names] else: col_masks = [np.repeat(True, len(self.data))] # Construct masks for the hue variable if self.hue_names: hue_masks = [data[self._hue_var] == n for n in self.hue_names] else: hue_masks = [np.repeat(True, len(self.data))] # Here is the main generator loop for (i, row), (j, col), (k, hue) in product(enumerate(row_masks), enumerate(col_masks), enumerate(hue_masks)): data_ijk = data[row & col & hue & self._not_na] yield (i, j, k), data_ijk def map(self, func, *args, **kwargs): """Apply a plotting function to each facet's subset of the data. 
Parameters ---------- func : callable A plotting function that takes data and keyword arguments. It must plot to the currently active matplotlib Axes and take a `color` keyword argument. If faceting on the `hue` dimension, it must also take a `label` keyword argument. args : strings Column names in self.data that identify variables with data to plot. The data for each variable is passed to `func` in the order the variables are specified in the call. kwargs : keyword arguments All keyword arguments are passed to the plotting function. Returns ------- self : object Returns self. """ # If color was a keyword argument, grab it here kw_color = kwargs.pop("color", None) # How we use the function depends on where it comes from func_module = str(getattr(func, "__module__", "")) # Check for categorical plots without order information if func_module == "seaborn.categorical": if "order" not in kwargs: warning = ("Using the {} function without specifying " "`order` is likely to produce an incorrect " "plot.".format(func.__name__)) warnings.warn(warning) if len(args) == 3 and "hue_order" not in kwargs: warning = ("Using the {} function without specifying " "`hue_order` is likely to produce an incorrect " "plot.".format(func.__name__)) warnings.warn(warning) # Iterate over the data subsets for (row_i, col_j, hue_k), data_ijk in self.facet_data(): # If this subset is null, move on if not data_ijk.values.size: continue # Get the current axis modify_state = not func_module.startswith("seaborn") ax = self.facet_axis(row_i, col_j, modify_state) # Decide what color to plot with kwargs["color"] = self._facet_color(hue_k, kw_color) # Insert the other hue aesthetics if appropriate for kw, val_list in self.hue_kws.items(): kwargs[kw] = val_list[hue_k] # Insert a label in the keyword arguments for the legend if self._hue_var is not None: kwargs["label"] = utils.to_utf8(self.hue_names[hue_k]) # Get the actual data we are going to plot with plot_data = data_ijk[list(args)] if self._dropna: 
plot_data = plot_data.dropna() plot_args = [v for k, v in plot_data.items()] # Some matplotlib functions don't handle pandas objects correctly if func_module.startswith("matplotlib"): plot_args = [v.values for v in plot_args] # Draw the plot self._facet_plot(func, ax, plot_args, kwargs) # Finalize the annotations and layout self._finalize_grid(args[:2]) return self def map_dataframe(self, func, *args, **kwargs): """Like ``.map`` but passes args as strings and inserts data in kwargs. This method is suitable for plotting with functions that accept a long-form DataFrame as a `data` keyword argument and access the data in that DataFrame using string variable names. Parameters ---------- func : callable A plotting function that takes data and keyword arguments. Unlike the `map` method, a function used here must "understand" Pandas objects. It also must plot to the currently active matplotlib Axes and take a `color` keyword argument. If faceting on the `hue` dimension, it must also take a `label` keyword argument. args : strings Column names in self.data that identify variables with data to plot. The data for each variable is passed to `func` in the order the variables are specified in the call. kwargs : keyword arguments All keyword arguments are passed to the plotting function. Returns ------- self : object Returns self. 
""" # If color was a keyword argument, grab it here kw_color = kwargs.pop("color", None) # Iterate over the data subsets for (row_i, col_j, hue_k), data_ijk in self.facet_data(): # If this subset is null, move on if not data_ijk.values.size: continue # Get the current axis modify_state = not str(func.__module__).startswith("seaborn") ax = self.facet_axis(row_i, col_j, modify_state) # Decide what color to plot with kwargs["color"] = self._facet_color(hue_k, kw_color) # Insert the other hue aesthetics if appropriate for kw, val_list in self.hue_kws.items(): kwargs[kw] = val_list[hue_k] # Insert a label in the keyword arguments for the legend if self._hue_var is not None: kwargs["label"] = self.hue_names[hue_k] # Stick the facet dataframe into the kwargs if self._dropna: data_ijk = data_ijk.dropna() kwargs["data"] = data_ijk # Draw the plot self._facet_plot(func, ax, args, kwargs) # For axis labels, prefer to use positional args for backcompat # but also extract the x/y kwargs and use if no corresponding arg axis_labels = [kwargs.get("x", None), kwargs.get("y", None)] for i, val in enumerate(args[:2]): axis_labels[i] = val self._finalize_grid(axis_labels) return self def _facet_color(self, hue_index, kw_color): color = self._colors[hue_index] if kw_color is not None: return kw_color elif color is not None: return color def _facet_plot(self, func, ax, plot_args, plot_kwargs): # Draw the plot if str(func.__module__).startswith("seaborn"): plot_kwargs = plot_kwargs.copy() semantics = ["x", "y", "hue", "size", "style"] for key, val in zip(semantics, plot_args): plot_kwargs[key] = val plot_args = [] plot_kwargs["ax"] = ax func(*plot_args, **plot_kwargs) # Sort out the supporting information self._update_legend_data(ax) def _finalize_grid(self, axlabels): """Finalize the annotations and layout.""" self.set_axis_labels(*axlabels) self.tight_layout() def facet_axis(self, row_i, col_j, modify_state=True): """Make the axis identified by these indices active and return it.""" # 
Calculate the actual indices of the axes to plot on if self._col_wrap is not None: ax = self.axes.flat[col_j] else: ax = self.axes[row_i, col_j] # Get a reference to the axes object we want, and make it active if modify_state: plt.sca(ax) return ax def despine(self, **kwargs): """Remove axis spines from the facets.""" utils.despine(self._figure, **kwargs) return self def set_axis_labels(self, x_var=None, y_var=None, clear_inner=True, **kwargs): """Set axis labels on the left column and bottom row of the grid.""" if x_var is not None: self._x_var = x_var self.set_xlabels(x_var, clear_inner=clear_inner, **kwargs) if y_var is not None: self._y_var = y_var self.set_ylabels(y_var, clear_inner=clear_inner, **kwargs) return self def set_xlabels(self, label=None, clear_inner=True, **kwargs): """Label the x axis on the bottom row of the grid.""" if label is None: label = self._x_var for ax in self._bottom_axes: ax.set_xlabel(label, **kwargs) if clear_inner: for ax in self._not_bottom_axes: ax.set_xlabel("") return self def set_ylabels(self, label=None, clear_inner=True, **kwargs): """Label the y axis on the left column of the grid.""" if label is None: label = self._y_var for ax in self._left_axes: ax.set_ylabel(label, **kwargs) if clear_inner: for ax in self._not_left_axes: ax.set_ylabel("") return self def set_xticklabels(self, labels=None, step=None, **kwargs): """Set x axis tick labels of the grid.""" for ax in self.axes.flat: curr_ticks = ax.get_xticks() ax.set_xticks(curr_ticks) if labels is None: curr_labels = [label.get_text() for label in ax.get_xticklabels()] if step is not None: xticks = ax.get_xticks()[::step] curr_labels = curr_labels[::step] ax.set_xticks(xticks) ax.set_xticklabels(curr_labels, **kwargs) else: ax.set_xticklabels(labels, **kwargs) return self def set_yticklabels(self, labels=None, **kwargs): """Set y axis tick labels on the left column of the grid.""" for ax in self.axes.flat: curr_ticks = ax.get_yticks() ax.set_yticks(curr_ticks) if labels is 
None: curr_labels = [label.get_text() for label in ax.get_yticklabels()] ax.set_yticklabels(curr_labels, **kwargs) else: ax.set_yticklabels(labels, **kwargs) return self def set_titles(self, template=None, row_template=None, col_template=None, **kwargs): """Draw titles either above each facet or on the grid margins. Parameters ---------- template : string Template for all titles with the formatting keys {col_var} and {col_name} (if using a `col` faceting variable) and/or {row_var} and {row_name} (if using a `row` faceting variable). row_template: Template for the row variable when titles are drawn on the grid margins. Must have {row_var} and {row_name} formatting keys. col_template: Template for the column variable when titles are drawn on the grid margins. Must have {col_var} and {col_name} formatting keys. Returns ------- self: object Returns self. """ args = dict(row_var=self._row_var, col_var=self._col_var) kwargs["size"] = kwargs.pop("size", mpl.rcParams["axes.labelsize"]) # Establish default templates if row_template is None: row_template = "{row_var} = {row_name}" if col_template is None: col_template = "{col_var} = {col_name}" if template is None: if self._row_var is None: template = col_template elif self._col_var is None: template = row_template else: template = " | ".join([row_template, col_template]) row_template = utils.to_utf8(row_template) col_template = utils.to_utf8(col_template) template = utils.to_utf8(template) if self._margin_titles: # Remove any existing title texts for text in self._margin_titles_texts: text.remove() self._margin_titles_texts = [] if self.row_names is not None: # Draw the row titles on the right edge of the grid for i, row_name in enumerate(self.row_names): ax = self.axes[i, -1] args.update(dict(row_name=row_name)) title = row_template.format(**args) text = ax.annotate( title, xy=(1.02, .5), xycoords="axes fraction", rotation=270, ha="left", va="center", **kwargs ) self._margin_titles_texts.append(text) if self.col_names is 
not None: # Draw the column titles as normal titles for j, col_name in enumerate(self.col_names): args.update(dict(col_name=col_name)) title = col_template.format(**args) self.axes[0, j].set_title(title, **kwargs) return self # Otherwise title each facet with all the necessary information if (self._row_var is not None) and (self._col_var is not None): for i, row_name in enumerate(self.row_names): for j, col_name in enumerate(self.col_names): args.update(dict(row_name=row_name, col_name=col_name)) title = template.format(**args) self.axes[i, j].set_title(title, **kwargs) elif self.row_names is not None and len(self.row_names): for i, row_name in enumerate(self.row_names): args.update(dict(row_name=row_name)) title = template.format(**args) self.axes[i, 0].set_title(title, **kwargs) elif self.col_names is not None and len(self.col_names): for i, col_name in enumerate(self.col_names): args.update(dict(col_name=col_name)) title = template.format(**args) # Index the flat array so col_wrap works self.axes.flat[i].set_title(title, **kwargs) return self def refline(self, *, x=None, y=None, color='.5', linestyle='--', **line_kws): """Add a reference line(s) to each facet. Parameters ---------- x, y : numeric Value(s) to draw the line(s) at. color : :mod:`matplotlib color ` Specifies the color of the reference line(s). Pass ``color=None`` to use ``hue`` mapping. linestyle : str Specifies the style of the reference line(s). line_kws : key, value mappings Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline` when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y`` is not None. Returns ------- :class:`FacetGrid` instance Returns ``self`` for easy method chaining. 
""" line_kws['color'] = color line_kws['linestyle'] = linestyle if x is not None: self.map(plt.axvline, x=x, **line_kws) if y is not None: self.map(plt.axhline, y=y, **line_kws) return self # ------ Properties that are part of the public API and documented by Sphinx @property def axes(self): """An array of the :class:`matplotlib.axes.Axes` objects in the grid.""" return self._axes @property def ax(self): """The :class:`matplotlib.axes.Axes` when no faceting variables are assigned.""" if self.axes.shape == (1, 1): return self.axes[0, 0] else: err = ( "Use the `.axes` attribute when facet variables are assigned." ) raise AttributeError(err) @property def axes_dict(self): """A mapping of facet names to corresponding :class:`matplotlib.axes.Axes`. If only one of ``row`` or ``col`` is assigned, each key is a string representing a level of that variable. If both facet dimensions are assigned, each key is a ``({row_level}, {col_level})`` tuple. """ return self._axes_dict # ------ Private properties, that require some computation to get @property def _inner_axes(self): """Return a flat array of the inner axes.""" if self._col_wrap is None: return self.axes[:-1, 1:].flat else: axes = [] n_empty = self._nrow * self._ncol - self._n_facets for i, ax in enumerate(self.axes): append = ( i % self._ncol and i < (self._ncol * (self._nrow - 1)) and i < (self._ncol * (self._nrow - 1) - n_empty) ) if append: axes.append(ax) return np.array(axes, object).flat @property def _left_axes(self): """Return a flat array of the left column of axes.""" if self._col_wrap is None: return self.axes[:, 0].flat else: axes = [] for i, ax in enumerate(self.axes): if not i % self._ncol: axes.append(ax) return np.array(axes, object).flat @property def _not_left_axes(self): """Return a flat array of axes that aren't on the left column.""" if self._col_wrap is None: return self.axes[:, 1:].flat else: axes = [] for i, ax in enumerate(self.axes): if i % self._ncol: axes.append(ax) return np.array(axes, 
object).flat @property def _bottom_axes(self): """Return a flat array of the bottom row of axes.""" if self._col_wrap is None: return self.axes[-1, :].flat else: axes = [] n_empty = self._nrow * self._ncol - self._n_facets for i, ax in enumerate(self.axes): append = ( i >= (self._ncol * (self._nrow - 1)) or i >= (self._ncol * (self._nrow - 1) - n_empty) ) if append: axes.append(ax) return np.array(axes, object).flat @property def _not_bottom_axes(self): """Return a flat array of axes that aren't on the bottom row.""" if self._col_wrap is None: return self.axes[:-1, :].flat else: axes = [] n_empty = self._nrow * self._ncol - self._n_facets for i, ax in enumerate(self.axes): append = ( i < (self._ncol * (self._nrow - 1)) and i < (self._ncol * (self._nrow - 1) - n_empty) ) if append: axes.append(ax) return np.array(axes, object).flat class PairGrid(Grid): """Subplot grid for plotting pairwise relationships in a dataset. This object maps each variable in a dataset onto a column and row in a grid of multiple axes. Different axes-level plotting functions can be used to draw bivariate plots in the upper and lower triangles, and the marginal distribution of each variable can be shown on the diagonal. Several different common plots can be generated in a single line using :func:`pairplot`. Use :class:`PairGrid` when you need more flexibility. See the :ref:`tutorial ` for more information. """ def __init__( self, data, *, hue=None, vars=None, x_vars=None, y_vars=None, hue_order=None, palette=None, hue_kws=None, corner=False, diag_sharey=True, height=2.5, aspect=1, layout_pad=.5, despine=True, dropna=False, ): """Initialize the plot figure and PairGrid object. Parameters ---------- data : DataFrame Tidy (long-form) dataframe where each column is a variable and each row is an observation. hue : string (variable name) Variable in ``data`` to map plot aspects to different colors. This variable will be excluded from the default x and y variables. 
vars : list of variable names Variables within ``data`` to use, otherwise use every column with a numeric datatype. {x, y}_vars : lists of variable names Variables within ``data`` to use separately for the rows and columns of the figure; i.e. to make a non-square plot. hue_order : list of strings Order for the levels of the hue variable in the palette palette : dict or seaborn color palette Set of colors for mapping the ``hue`` variable. If a dict, keys should be values in the ``hue`` variable. hue_kws : dictionary of param -> list of values mapping Other keyword arguments to insert into the plotting call to let other plot attributes vary across levels of the hue variable (e.g. the markers in a scatterplot). corner : bool If True, don't add axes to the upper (off-diagonal) triangle of the grid, making this a "corner" plot. height : scalar Height (in inches) of each facet. aspect : scalar Aspect * height gives the width (in inches) of each facet. layout_pad : scalar Padding between axes; passed to ``fig.tight_layout``. despine : boolean Remove the top and right spines from the plots. dropna : boolean Drop missing values from the data before plotting. See Also -------- pairplot : Easily drawing common uses of :class:`PairGrid`. FacetGrid : Subplot grid for plotting conditional relationships. Examples -------- .. 
include:: ../docstrings/PairGrid.rst """ super().__init__() data = handle_data_source(data) # Sort out the variables that define the grid numeric_cols = self._find_numeric_cols(data) if hue in numeric_cols: numeric_cols.remove(hue) if vars is not None: x_vars = list(vars) y_vars = list(vars) if x_vars is None: x_vars = numeric_cols if y_vars is None: y_vars = numeric_cols if np.isscalar(x_vars): x_vars = [x_vars] if np.isscalar(y_vars): y_vars = [y_vars] self.x_vars = x_vars = list(x_vars) self.y_vars = y_vars = list(y_vars) self.square_grid = self.x_vars == self.y_vars if not x_vars: raise ValueError("No variables found for grid columns.") if not y_vars: raise ValueError("No variables found for grid rows.") # Create the figure and the array of subplots figsize = len(x_vars) * height * aspect, len(y_vars) * height with _disable_autolayout(): fig = plt.figure(figsize=figsize) axes = fig.subplots(len(y_vars), len(x_vars), sharex="col", sharey="row", squeeze=False) # Possibly remove upper axes to make a corner grid # Note: setting up the axes is usually the most time-intensive part # of using the PairGrid. We are foregoing the speed improvement that # we would get by just not setting up the hidden axes so that we can # avoid implementing fig.subplots ourselves. But worth thinking about. 
self._corner = corner if corner: hide_indices = np.triu_indices_from(axes, 1) for i, j in zip(*hide_indices): axes[i, j].remove() axes[i, j] = None self._figure = fig self.axes = axes self.data = data # Save what we are going to do with the diagonal self.diag_sharey = diag_sharey self.diag_vars = None self.diag_axes = None self._dropna = dropna # Label the axes self._add_axis_labels() # Sort out the hue variable self._hue_var = hue if hue is None: self.hue_names = hue_order = ["_nolegend_"] self.hue_vals = pd.Series(["_nolegend_"] * len(data), index=data.index) else: # We need hue_order and hue_names because the former is used to control # the order of drawing and the latter is used to control the order of # the legend. hue_names can become string-typed while hue_order must # retain the type of the input data. This is messy but results from # the fact that PairGrid can implement the hue-mapping logic itself # (and was originally written exclusively that way) but now can delegate # to the axes-level functions, while always handling legend creation. 
# See GH2307 hue_names = hue_order = categorical_order(data[hue], hue_order) if dropna: # Filter NA from the list of unique hue names hue_names = list(filter(pd.notnull, hue_names)) self.hue_names = hue_names self.hue_vals = data[hue] # Additional dict of kwarg -> list of values for mapping the hue var self.hue_kws = hue_kws if hue_kws is not None else {} self._orig_palette = palette self._hue_order = hue_order self.palette = self._get_palette(data, hue, hue_order, palette) self._legend_data = {} # Make the plot look nice for ax in axes[:-1, :].flat: if ax is None: continue for label in ax.get_xticklabels(): label.set_visible(False) ax.xaxis.offsetText.set_visible(False) ax.xaxis.label.set_visible(False) for ax in axes[:, 1:].flat: if ax is None: continue for label in ax.get_yticklabels(): label.set_visible(False) ax.yaxis.offsetText.set_visible(False) ax.yaxis.label.set_visible(False) self._tight_layout_rect = [.01, .01, .99, .99] self._tight_layout_pad = layout_pad self._despine = despine if despine: utils.despine(fig=fig) self.tight_layout(pad=layout_pad) def map(self, func, **kwargs): """Plot with the same function in every subplot. Parameters ---------- func : callable plotting function Must take x, y arrays as positional arguments and draw onto the "currently active" matplotlib Axes. Also needs to accept kwargs called ``color`` and ``label``. """ row_indices, col_indices = np.indices(self.axes.shape) indices = zip(row_indices.flat, col_indices.flat) self._map_bivariate(func, indices, **kwargs) return self def map_lower(self, func, **kwargs): """Plot with a bivariate function on the lower diagonal subplots. Parameters ---------- func : callable plotting function Must take x, y arrays as positional arguments and draw onto the "currently active" matplotlib Axes. Also needs to accept kwargs called ``color`` and ``label``. 
""" indices = zip(*np.tril_indices_from(self.axes, -1)) self._map_bivariate(func, indices, **kwargs) return self def map_upper(self, func, **kwargs): """Plot with a bivariate function on the upper diagonal subplots. Parameters ---------- func : callable plotting function Must take x, y arrays as positional arguments and draw onto the "currently active" matplotlib Axes. Also needs to accept kwargs called ``color`` and ``label``. """ indices = zip(*np.triu_indices_from(self.axes, 1)) self._map_bivariate(func, indices, **kwargs) return self def map_offdiag(self, func, **kwargs): """Plot with a bivariate function on the off-diagonal subplots. Parameters ---------- func : callable plotting function Must take x, y arrays as positional arguments and draw onto the "currently active" matplotlib Axes. Also needs to accept kwargs called ``color`` and ``label``. """ if self.square_grid: self.map_lower(func, **kwargs) if not self._corner: self.map_upper(func, **kwargs) else: indices = [] for i, (y_var) in enumerate(self.y_vars): for j, (x_var) in enumerate(self.x_vars): if x_var != y_var: indices.append((i, j)) self._map_bivariate(func, indices, **kwargs) return self def map_diag(self, func, **kwargs): """Plot with a univariate function on each diagonal subplot. Parameters ---------- func : callable plotting function Must take an x array as a positional argument and draw onto the "currently active" matplotlib Axes. Also needs to accept kwargs called ``color`` and ``label``. 
""" # Add special diagonal axes for the univariate plot if self.diag_axes is None: diag_vars = [] diag_axes = [] for i, y_var in enumerate(self.y_vars): for j, x_var in enumerate(self.x_vars): if x_var == y_var: # Make the density axes diag_vars.append(x_var) ax = self.axes[i, j] diag_ax = ax.twinx() diag_ax.set_axis_off() diag_axes.append(diag_ax) # Work around matplotlib bug # https://github.com/matplotlib/matplotlib/issues/15188 if not plt.rcParams.get("ytick.left", True): for tick in ax.yaxis.majorTicks: tick.tick1line.set_visible(False) # Remove main y axis from density axes in a corner plot if self._corner: ax.yaxis.set_visible(False) if self._despine: utils.despine(ax=ax, left=True) # TODO add optional density ticks (on the right) # when drawing a corner plot? if self.diag_sharey and diag_axes: for ax in diag_axes[1:]: share_axis(diag_axes[0], ax, "y") self.diag_vars = diag_vars self.diag_axes = diag_axes if "hue" not in signature(func).parameters: return self._map_diag_iter_hue(func, **kwargs) # Loop over diagonal variables and axes, making one plot in each for var, ax in zip(self.diag_vars, self.diag_axes): plot_kwargs = kwargs.copy() if str(func.__module__).startswith("seaborn"): plot_kwargs["ax"] = ax else: plt.sca(ax) vector = self.data[var] if self._hue_var is not None: hue = self.data[self._hue_var] else: hue = None if self._dropna: not_na = vector.notna() if hue is not None: not_na &= hue.notna() vector = vector[not_na] if hue is not None: hue = hue[not_na] plot_kwargs.setdefault("hue", hue) plot_kwargs.setdefault("hue_order", self._hue_order) plot_kwargs.setdefault("palette", self._orig_palette) func(x=vector, **plot_kwargs) ax.legend_ = None self._add_axis_labels() return self def _map_diag_iter_hue(self, func, **kwargs): """Put marginal plot on each diagonal axes, iterating over hue.""" # Plot on each of the diagonal axes fixed_color = kwargs.pop("color", None) for var, ax in zip(self.diag_vars, self.diag_axes): hue_grouped = 
self.data[var].groupby(self.hue_vals, observed=True) plot_kwargs = kwargs.copy() if str(func.__module__).startswith("seaborn"): plot_kwargs["ax"] = ax else: plt.sca(ax) for k, label_k in enumerate(self._hue_order): # Attempt to get data for this level, allowing for empty try: data_k = hue_grouped.get_group(label_k) except KeyError: data_k = pd.Series([], dtype=float) if fixed_color is None: color = self.palette[k] else: color = fixed_color if self._dropna: data_k = utils.remove_na(data_k) if str(func.__module__).startswith("seaborn"): func(x=data_k, label=label_k, color=color, **plot_kwargs) else: func(data_k, label=label_k, color=color, **plot_kwargs) self._add_axis_labels() return self def _map_bivariate(self, func, indices, **kwargs): """Draw a bivariate plot on the indicated axes.""" # This is a hack to handle the fact that new distribution plots don't add # their artists onto the axes. This is probably superior in general, but # we'll need a better way to handle it in the axisgrid functions. from .distributions import histplot, kdeplot if func is histplot or func is kdeplot: self._extract_legend_handles = True kws = kwargs.copy() # Use copy as we insert other kwargs for i, j in indices: x_var = self.x_vars[j] y_var = self.y_vars[i] ax = self.axes[i, j] if ax is None: # i.e. 
we are in corner mode continue self._plot_bivariate(x_var, y_var, ax, func, **kws) self._add_axis_labels() if "hue" in signature(func).parameters: self.hue_names = list(self._legend_data) def _plot_bivariate(self, x_var, y_var, ax, func, **kwargs): """Draw a bivariate plot on the specified axes.""" if "hue" not in signature(func).parameters: self._plot_bivariate_iter_hue(x_var, y_var, ax, func, **kwargs) return kwargs = kwargs.copy() if str(func.__module__).startswith("seaborn"): kwargs["ax"] = ax else: plt.sca(ax) if x_var == y_var: axes_vars = [x_var] else: axes_vars = [x_var, y_var] if self._hue_var is not None and self._hue_var not in axes_vars: axes_vars.append(self._hue_var) data = self.data[axes_vars] if self._dropna: data = data.dropna() x = data[x_var] y = data[y_var] if self._hue_var is None: hue = None else: hue = data.get(self._hue_var) if "hue" not in kwargs: kwargs.update({ "hue": hue, "hue_order": self._hue_order, "palette": self._orig_palette, }) func(x=x, y=y, **kwargs) self._update_legend_data(ax) def _plot_bivariate_iter_hue(self, x_var, y_var, ax, func, **kwargs): """Draw a bivariate plot while iterating over hue subsets.""" kwargs = kwargs.copy() if str(func.__module__).startswith("seaborn"): kwargs["ax"] = ax else: plt.sca(ax) if x_var == y_var: axes_vars = [x_var] else: axes_vars = [x_var, y_var] hue_grouped = self.data.groupby(self.hue_vals, observed=True) for k, label_k in enumerate(self._hue_order): kws = kwargs.copy() # Attempt to get data for this level, allowing for empty try: data_k = hue_grouped.get_group(label_k) except KeyError: data_k = pd.DataFrame(columns=axes_vars, dtype=float) if self._dropna: data_k = data_k[axes_vars].dropna() x = data_k[x_var] y = data_k[y_var] for kw, val_list in self.hue_kws.items(): kws[kw] = val_list[k] kws.setdefault("color", self.palette[k]) if self._hue_var is not None: kws["label"] = label_k if str(func.__module__).startswith("seaborn"): func(x=x, y=y, **kws) else: func(x, y, **kws) 
self._update_legend_data(ax) def _add_axis_labels(self): """Add labels to the left and bottom Axes.""" for ax, label in zip(self.axes[-1, :], self.x_vars): ax.set_xlabel(label) for ax, label in zip(self.axes[:, 0], self.y_vars): ax.set_ylabel(label) def _find_numeric_cols(self, data): """Find which variables in a DataFrame are numeric.""" numeric_cols = [] for col in data: if variable_type(data[col]) == "numeric": numeric_cols.append(col) return numeric_cols class JointGrid(_BaseGrid): """Grid for drawing a bivariate plot with marginal univariate plots. Many plots can be drawn by using the figure-level interface :func:`jointplot`. Use this class directly when you need more flexibility. """ def __init__( self, data=None, *, x=None, y=None, hue=None, height=6, ratio=5, space=.2, palette=None, hue_order=None, hue_norm=None, dropna=False, xlim=None, ylim=None, marginal_ticks=False, ): # Set up the subplot grid f = plt.figure(figsize=(height, height)) gs = plt.GridSpec(ratio + 1, ratio + 1) ax_joint = f.add_subplot(gs[1:, :-1]) ax_marg_x = f.add_subplot(gs[0, :-1], sharex=ax_joint) ax_marg_y = f.add_subplot(gs[1:, -1], sharey=ax_joint) self._figure = f self.ax_joint = ax_joint self.ax_marg_x = ax_marg_x self.ax_marg_y = ax_marg_y # Turn off tick visibility for the measure axis on the marginal plots plt.setp(ax_marg_x.get_xticklabels(), visible=False) plt.setp(ax_marg_y.get_yticklabels(), visible=False) plt.setp(ax_marg_x.get_xticklabels(minor=True), visible=False) plt.setp(ax_marg_y.get_yticklabels(minor=True), visible=False) # Turn off the ticks on the density axis for the marginal plots if not marginal_ticks: plt.setp(ax_marg_x.yaxis.get_majorticklines(), visible=False) plt.setp(ax_marg_x.yaxis.get_minorticklines(), visible=False) plt.setp(ax_marg_y.xaxis.get_majorticklines(), visible=False) plt.setp(ax_marg_y.xaxis.get_minorticklines(), visible=False) plt.setp(ax_marg_x.get_yticklabels(), visible=False) plt.setp(ax_marg_y.get_xticklabels(), visible=False) 
plt.setp(ax_marg_x.get_yticklabels(minor=True), visible=False) plt.setp(ax_marg_y.get_xticklabels(minor=True), visible=False) ax_marg_x.yaxis.grid(False) ax_marg_y.xaxis.grid(False) # Process the input variables p = VectorPlotter(data=data, variables=dict(x=x, y=y, hue=hue)) plot_data = p.plot_data.loc[:, p.plot_data.notna().any()] # Possibly drop NA if dropna: plot_data = plot_data.dropna() def get_var(var): vector = plot_data.get(var, None) if vector is not None: vector = vector.rename(p.variables.get(var, None)) return vector self.x = get_var("x") self.y = get_var("y") self.hue = get_var("hue") for axis in "xy": name = p.variables.get(axis, None) if name is not None: getattr(ax_joint, f"set_{axis}label")(name) if xlim is not None: ax_joint.set_xlim(xlim) if ylim is not None: ax_joint.set_ylim(ylim) # Store the semantic mapping parameters for axes-level functions self._hue_params = dict(palette=palette, hue_order=hue_order, hue_norm=hue_norm) # Make the grid look nice utils.despine(f) if not marginal_ticks: utils.despine(ax=ax_marg_x, left=True) utils.despine(ax=ax_marg_y, bottom=True) for axes in [ax_marg_x, ax_marg_y]: for axis in [axes.xaxis, axes.yaxis]: axis.label.set_visible(False) f.tight_layout() f.subplots_adjust(hspace=space, wspace=space) def _inject_kwargs(self, func, kws, params): """Add params to kws if they are accepted by func.""" func_params = signature(func).parameters for key, val in params.items(): if key in func_params: kws.setdefault(key, val) def plot(self, joint_func, marginal_func, **kwargs): """Draw the plot by passing functions for joint and marginal axes. This method passes the ``kwargs`` dictionary to both functions. If you need more control, call :meth:`JointGrid.plot_joint` and :meth:`JointGrid.plot_marginals` directly with specific parameters. Parameters ---------- joint_func, marginal_func : callables Functions to draw the bivariate and univariate plots. 
See methods referenced above for information about the required characteristics of these functions. kwargs Additional keyword arguments are passed to both functions. Returns ------- :class:`JointGrid` instance Returns ``self`` for easy method chaining. """ self.plot_marginals(marginal_func, **kwargs) self.plot_joint(joint_func, **kwargs) return self def plot_joint(self, func, **kwargs): """Draw a bivariate plot on the joint axes of the grid. Parameters ---------- func : plotting callable If a seaborn function, it should accept ``x`` and ``y``. Otherwise, it must accept ``x`` and ``y`` vectors of data as the first two positional arguments, and it must plot on the "current" axes. If ``hue`` was defined in the class constructor, the function must accept ``hue`` as a parameter. kwargs Keyword argument are passed to the plotting function. Returns ------- :class:`JointGrid` instance Returns ``self`` for easy method chaining. """ kwargs = kwargs.copy() if str(func.__module__).startswith("seaborn"): kwargs["ax"] = self.ax_joint else: plt.sca(self.ax_joint) if self.hue is not None: kwargs["hue"] = self.hue self._inject_kwargs(func, kwargs, self._hue_params) if str(func.__module__).startswith("seaborn"): func(x=self.x, y=self.y, **kwargs) else: func(self.x, self.y, **kwargs) return self def plot_marginals(self, func, **kwargs): """Draw univariate plots on each marginal axes. Parameters ---------- func : plotting callable If a seaborn function, it should accept ``x`` and ``y`` and plot when only one of them is defined. Otherwise, it must accept a vector of data as the first positional argument and determine its orientation using the ``vertical`` parameter, and it must plot on the "current" axes. If ``hue`` was defined in the class constructor, it must accept ``hue`` as a parameter. kwargs Keyword argument are passed to the plotting function. Returns ------- :class:`JointGrid` instance Returns ``self`` for easy method chaining. 
""" seaborn_func = ( str(func.__module__).startswith("seaborn") # deprecated distplot has a legacy API, special case it and not func.__name__ == "distplot" ) func_params = signature(func).parameters kwargs = kwargs.copy() if self.hue is not None: kwargs["hue"] = self.hue self._inject_kwargs(func, kwargs, self._hue_params) if "legend" in func_params: kwargs.setdefault("legend", False) if "orientation" in func_params: # e.g. plt.hist orient_kw_x = {"orientation": "vertical"} orient_kw_y = {"orientation": "horizontal"} elif "vertical" in func_params: # e.g. sns.distplot (also how did this get backwards?) orient_kw_x = {"vertical": False} orient_kw_y = {"vertical": True} if seaborn_func: func(x=self.x, ax=self.ax_marg_x, **kwargs) else: plt.sca(self.ax_marg_x) func(self.x, **orient_kw_x, **kwargs) if seaborn_func: func(y=self.y, ax=self.ax_marg_y, **kwargs) else: plt.sca(self.ax_marg_y) func(self.y, **orient_kw_y, **kwargs) self.ax_marg_x.yaxis.get_label().set_visible(False) self.ax_marg_y.xaxis.get_label().set_visible(False) return self def refline( self, *, x=None, y=None, joint=True, marginal=True, color='.5', linestyle='--', **line_kws ): """Add a reference line(s) to joint and/or marginal axes. Parameters ---------- x, y : numeric Value(s) to draw the line(s) at. joint, marginal : bools Whether to add the reference line(s) to the joint/marginal axes. color : :mod:`matplotlib color ` Specifies the color of the reference line(s). linestyle : str Specifies the style of the reference line(s). line_kws : key, value mappings Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.axvline` when ``x`` is not None and :meth:`matplotlib.axes.Axes.axhline` when ``y`` is not None. Returns ------- :class:`JointGrid` instance Returns ``self`` for easy method chaining. 
""" line_kws['color'] = color line_kws['linestyle'] = linestyle if x is not None: if joint: self.ax_joint.axvline(x, **line_kws) if marginal: self.ax_marg_x.axvline(x, **line_kws) if y is not None: if joint: self.ax_joint.axhline(y, **line_kws) if marginal: self.ax_marg_y.axhline(y, **line_kws) return self def set_axis_labels(self, xlabel="", ylabel="", **kwargs): """Set axis labels on the bivariate axes. Parameters ---------- xlabel, ylabel : strings Label names for the x and y variables. kwargs : key, value mappings Other keyword arguments are passed to the following functions: - :meth:`matplotlib.axes.Axes.set_xlabel` - :meth:`matplotlib.axes.Axes.set_ylabel` Returns ------- :class:`JointGrid` instance Returns ``self`` for easy method chaining. """ self.ax_joint.set_xlabel(xlabel, **kwargs) self.ax_joint.set_ylabel(ylabel, **kwargs) return self JointGrid.__init__.__doc__ = """\ Set up the grid of subplots and store data internally for easy plotting. Parameters ---------- {params.core.data} {params.core.xy} height : number Size of each side of the figure in inches (it will be square). ratio : number Ratio of joint axes height to marginal axes height. space : number Space between the joint and marginal axes dropna : bool If True, remove missing observations before plotting. {{x, y}}lim : pairs of numbers Set axis limits to these values before plotting. marginal_ticks : bool If False, suppress ticks on the count/density axis of the marginal plots. {params.core.hue} Note: unlike in :class:`FacetGrid` or :class:`PairGrid`, the axes-level functions must support ``hue`` to use it in :class:`JointGrid`. {params.core.palette} {params.core.hue_order} {params.core.hue_norm} See Also -------- {seealso.jointplot} {seealso.pairgrid} {seealso.pairplot} Examples -------- .. 
include:: ../docstrings/JointGrid.rst """.format( params=_param_docs, seealso=_core_docs["seealso"], ) def pairplot( data, *, hue=None, hue_order=None, palette=None, vars=None, x_vars=None, y_vars=None, kind="scatter", diag_kind="auto", markers=None, height=2.5, aspect=1, corner=False, dropna=False, plot_kws=None, diag_kws=None, grid_kws=None, size=None, ): """Plot pairwise relationships in a dataset. By default, this function will create a grid of Axes such that each numeric variable in ``data`` will by shared across the y-axes across a single row and the x-axes across a single column. The diagonal plots are treated differently: a univariate distribution plot is drawn to show the marginal distribution of the data in each column. It is also possible to show a subset of variables or plot different variables on the rows and columns. This is a high-level interface for :class:`PairGrid` that is intended to make it easy to draw a few common styles. You should use :class:`PairGrid` directly if you need more flexibility. Parameters ---------- data : `pandas.DataFrame` Tidy (long-form) dataframe where each column is a variable and each row is an observation. hue : name of variable in ``data`` Variable in ``data`` to map plot aspects to different colors. hue_order : list of strings Order for the levels of the hue variable in the palette palette : dict or seaborn color palette Set of colors for mapping the ``hue`` variable. If a dict, keys should be values in the ``hue`` variable. vars : list of variable names Variables within ``data`` to use, otherwise use every column with a numeric datatype. {x, y}_vars : lists of variable names Variables within ``data`` to use separately for the rows and columns of the figure; i.e. to make a non-square plot. kind : {'scatter', 'kde', 'hist', 'reg'} Kind of plot to make. diag_kind : {'auto', 'hist', 'kde', None} Kind of plot for the diagonal subplots. If 'auto', choose based on whether or not ``hue`` is used. 
markers : single matplotlib marker code or list Either the marker to use for all scatterplot points or a list of markers with a length the same as the number of levels in the hue variable so that differently colored points will also have different scatterplot markers. height : scalar Height (in inches) of each facet. aspect : scalar Aspect * height gives the width (in inches) of each facet. corner : bool If True, don't add axes to the upper (off-diagonal) triangle of the grid, making this a "corner" plot. dropna : boolean Drop missing values from the data before plotting. {plot, diag, grid}_kws : dicts Dictionaries of keyword arguments. ``plot_kws`` are passed to the bivariate plotting function, ``diag_kws`` are passed to the univariate plotting function, and ``grid_kws`` are passed to the :class:`PairGrid` constructor. Returns ------- grid : :class:`PairGrid` Returns the underlying :class:`PairGrid` instance for further tweaking. See Also -------- PairGrid : Subplot grid for more flexible plotting of pairwise relationships. JointGrid : Grid for plotting joint and marginal distributions of two variables. Examples -------- .. 
include:: ../docstrings/pairplot.rst """ # Avoid circular import from .distributions import histplot, kdeplot # Handle deprecations if size is not None: height = size msg = ("The `size` parameter has been renamed to `height`; " "please update your code.") warnings.warn(msg, UserWarning) if not isinstance(data, pd.DataFrame): raise TypeError( f"'data' must be pandas DataFrame object, not: {type(data)}") plot_kws = {} if plot_kws is None else plot_kws.copy() diag_kws = {} if diag_kws is None else diag_kws.copy() grid_kws = {} if grid_kws is None else grid_kws.copy() # Resolve "auto" diag kind if diag_kind == "auto": if hue is None: diag_kind = "kde" if kind == "kde" else "hist" else: diag_kind = "hist" if kind == "hist" else "kde" # Set up the PairGrid grid_kws.setdefault("diag_sharey", diag_kind == "hist") grid = PairGrid(data, vars=vars, x_vars=x_vars, y_vars=y_vars, hue=hue, hue_order=hue_order, palette=palette, corner=corner, height=height, aspect=aspect, dropna=dropna, **grid_kws) # Add the markers here as PairGrid has figured out how many levels of the # hue variable are needed and we don't want to duplicate that process if markers is not None: if kind == "reg": # Needed until regplot supports style if grid.hue_names is None: n_markers = 1 else: n_markers = len(grid.hue_names) if not isinstance(markers, list): markers = [markers] * n_markers if len(markers) != n_markers: raise ValueError("markers must be a singleton or a list of " "markers for each level of the hue variable") grid.hue_kws = {"marker": markers} elif kind == "scatter": if isinstance(markers, str): plot_kws["marker"] = markers elif hue is not None: plot_kws["style"] = data[hue] plot_kws["markers"] = markers # Draw the marginal plots on the diagonal diag_kws = diag_kws.copy() diag_kws.setdefault("legend", False) if diag_kind == "hist": grid.map_diag(histplot, **diag_kws) elif diag_kind == "kde": diag_kws.setdefault("fill", True) diag_kws.setdefault("warn_singular", False) grid.map_diag(kdeplot, 
def jointplot(
    data=None, *, x=None, y=None, hue=None, kind="scatter",
    height=6, ratio=5, space=.2, dropna=False, xlim=None, ylim=None,
    color=None, palette=None, hue_order=None, hue_norm=None, marginal_ticks=False,
    joint_kws=None, marginal_kws=None,
    **kwargs
):
    # Draw a bivariate plot on the joint axes of a JointGrid together with
    # univariate plots on its marginal axes, and return the grid.
    # The user-facing docstring is assigned to ``jointplot.__doc__`` right after
    # this definition, so the body is documented with comments only.

    # Avoid circular imports
    from .relational import scatterplot
    from .regression import regplot, residplot
    from .distributions import histplot, kdeplot, _freedman_diaconis_bins

    # An `ax` argument is discarded (with a warning) rather than forwarded
    if kwargs.pop("ax", None) is not None:
        msg = "Ignoring `ax`; jointplot is a figure-level function."
        warnings.warn(msg, UserWarning, stacklevel=2)

    # Set up empty default kwarg dicts (copies, so caller dicts are not mutated);
    # loose **kwargs supersede entries in joint_kws
    joint_kws = {} if joint_kws is None else joint_kws.copy()
    joint_kws.update(kwargs)
    marginal_kws = {} if marginal_kws is None else marginal_kws.copy()

    # Handle deprecations of distplot-specific kwargs.
    # Every key needs its own comma: adjacent string literals are concatenated
    # by the parser, which would silently merge two keys into one bogus key.
    distplot_keys = [
        "rug", "fit", "hist_kws", "norm_hist", "rug_kws",
    ]
    unused_keys = []
    for key in distplot_keys:
        if key in marginal_kws:
            unused_keys.append(key)
            marginal_kws.pop(key)
    if unused_keys and kind != "kde":
        msg = (
            "The marginal plotting function has changed to `histplot`,"
            " which does not accept the following argument(s): {}."
        ).format(", ".join(unused_keys))
        warnings.warn(msg, UserWarning)

    # Validate the plot kind
    plot_kinds = ["scatter", "hist", "hex", "kde", "reg", "resid"]
    _check_argument("kind", plot_kinds, kind)

    # Raise early if using `hue` with a kind that does not support it
    if hue is not None and kind in ["hex", "reg", "resid"]:
        msg = f"Use of `hue` with `kind='{kind}'` is not currently supported."
        raise ValueError(msg)

    # Make a colormap based off the plot color
    # (Currently used only for kind="hex")
    if color is None:
        color = "C0"
    color_rgb = mpl.colors.colorConverter.to_rgb(color)
    colors = [set_hls_values(color_rgb, l=val) for val in np.linspace(1, 0, 12)]
    cmap = blend_palette(colors, as_cmap=True)

    # Matplotlib's hexbin plot is not na-robust
    if kind == "hex":
        dropna = True

    # Initialize the JointGrid object
    grid = JointGrid(
        data=data, x=x, y=y, hue=hue,
        palette=palette, hue_order=hue_order, hue_norm=hue_norm,
        dropna=dropna, height=height, ratio=ratio, space=space,
        xlim=xlim, ylim=ylim, marginal_ticks=marginal_ticks,
    )

    if grid.hue is not None:
        marginal_kws.setdefault("legend", False)

    # Plot the data using the grid
    if kind.startswith("scatter"):

        joint_kws.setdefault("color", color)
        grid.plot_joint(scatterplot, **joint_kws)

        # Histograms without hue, filled KDEs with hue
        if grid.hue is None:
            marg_func = histplot
        else:
            marg_func = kdeplot
            marginal_kws.setdefault("warn_singular", False)
            marginal_kws.setdefault("fill", True)

        marginal_kws.setdefault("color", color)
        grid.plot_marginals(marg_func, **marginal_kws)

    elif kind.startswith("hist"):

        # TODO process pair parameters for bins, etc. and pass
        # to both joint and marginal plots

        joint_kws.setdefault("color", color)
        grid.plot_joint(histplot, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)

        marg_x_kws = marginal_kws.copy()
        marg_y_kws = marginal_kws.copy()

        # A tuple-valued binning parameter on the joint plot is split so that
        # each marginal receives the value for its own variable
        pair_keys = "bins", "binwidth", "binrange"
        for key in pair_keys:
            if isinstance(joint_kws.get(key), tuple):
                x_val, y_val = joint_kws[key]
                marg_x_kws.setdefault(key, x_val)
                marg_y_kws.setdefault(key, y_val)

        histplot(data=data, x=x, hue=hue, **marg_x_kws, ax=grid.ax_marg_x)
        histplot(data=data, y=y, hue=hue, **marg_y_kws, ax=grid.ax_marg_y)

    elif kind.startswith("kde"):

        joint_kws.setdefault("color", color)
        joint_kws.setdefault("warn_singular", False)
        grid.plot_joint(kdeplot, **joint_kws)

        marginal_kws.setdefault("color", color)
        if "fill" in joint_kws:
            marginal_kws.setdefault("fill", joint_kws["fill"])

        grid.plot_marginals(kdeplot, **marginal_kws)

    elif kind.startswith("hex"):

        # Default gridsize: mean of the per-axis bin counts, each capped at 50
        x_bins = min(_freedman_diaconis_bins(grid.x), 50)
        y_bins = min(_freedman_diaconis_bins(grid.y), 50)
        gridsize = int(np.mean([x_bins, y_bins]))

        joint_kws.setdefault("gridsize", gridsize)
        joint_kws.setdefault("cmap", cmap)
        grid.plot_joint(plt.hexbin, **joint_kws)

        marginal_kws.setdefault("kde", False)
        marginal_kws.setdefault("color", color)
        grid.plot_marginals(histplot, **marginal_kws)

    elif kind.startswith("reg"):

        marginal_kws.setdefault("color", color)
        marginal_kws.setdefault("kde", True)
        grid.plot_marginals(histplot, **marginal_kws)

        joint_kws.setdefault("color", color)
        grid.plot_joint(regplot, **joint_kws)

    elif kind.startswith("resid"):

        joint_kws.setdefault("color", color)
        grid.plot_joint(residplot, **joint_kws)

        # The marginals are drawn from the plotted offsets of the first
        # collection on the joint axes, not from the input data
        x, y = grid.ax_joint.collections[0].get_offsets().T
        marginal_kws.setdefault("color", color)
        histplot(x=x, hue=hue, ax=grid.ax_marg_x, **marginal_kws)
        histplot(y=y, hue=hue, ax=grid.ax_marg_y, **marginal_kws)

    # Make the main axes active in the matplotlib state machine
    plt.sca(grid.ax_joint)

    return grid
plot of two variables with bivariate and univariate graphs. This function provides a convenient interface to the :class:`JointGrid` class, with several canned plot kinds. This is intended to be a fairly lightweight wrapper; if you need more flexibility, you should use :class:`JointGrid` directly. Parameters ---------- {params.core.data} {params.core.xy} {params.core.hue} kind : {{ "scatter" | "kde" | "hist" | "hex" | "reg" | "resid" }} Kind of plot to draw. See the examples for references to the underlying functions. height : numeric Size of the figure (it will be square). ratio : numeric Ratio of joint axes height to marginal axes height. space : numeric Space between the joint and marginal axes dropna : bool If True, remove observations that are missing from ``x`` and ``y``. {{x, y}}lim : pairs of numbers Axis limits to set before plotting. {params.core.color} {params.core.palette} {params.core.hue_order} {params.core.hue_norm} marginal_ticks : bool If False, suppress ticks on the count/density axis of the marginal plots. {{joint, marginal}}_kws : dicts Additional keyword arguments for the plot components. kwargs Additional keyword arguments are passed to the function used to draw the plot on the joint Axes, superseding items in the ``joint_kws`` dictionary. Returns ------- {returns.jointgrid} See Also -------- {seealso.jointgrid} {seealso.pairgrid} {seealso.pairplot} Examples -------- .. 
def __init__(
    self,
    data=None,
    variables={},
    order=None,
    orient=None,
    require_numeric=False,
    color=None,
    legend="auto",
):
    """Set up plot data, orientation, and categorical levels.

    Parameters
    ----------
    data, variables
        Forwarded unchanged to ``VectorPlotter.__init__``.
    order
        Passed to ``categorical_order`` to fix the order of category levels.
    orient
        Requested orientation; resolved through ``infer_orient``. For wide-form
        input, ``"h"`` or ``"y"`` also swaps the x/y assignment.
    require_numeric
        NOTE(review): accepted but not forwarded; ``infer_orient`` is always
        called with ``require_numeric=False`` -- confirm this is intended.
    color
        When given together with wide-form data, the automatic hue assignment
        is removed.
    legend
        Stored on ``self.legend``.

    """
    # `variables` is only handed to the parent initializer and is not mutated
    # in this method, so the shared default dict is not modified here.
    super().__init__(data=data, variables=variables)

    # This method takes care of some bookkeeping that is necessary because the
    # original categorical plots (prior to the 2021 refactor) had some rules that
    # don't fit exactly into VectorPlotter logic. It may be wise to have a second
    # round of refactoring that moves the logic deeper, but this will keep things
    # relatively sensible for now.

    # For wide data, orient determines assignment to x/y differently from the
    # default VectorPlotter rules. If we do decide to make orient part of the
    # _base variable assignment, we'll want to figure out how to express that.
    if self.input_format == "wide" and orient in ["h", "y"]:
        self.plot_data = self.plot_data.rename(columns={"x": "y", "y": "x"})
        orig_variables = set(self.variables)
        orig_x = self.variables.pop("x", None)
        orig_y = self.variables.pop("y", None)
        orig_x_type = self.var_types.pop("x", None)
        orig_y_type = self.var_types.pop("y", None)
        if "x" in orig_variables:
            self.variables["y"] = orig_x
            self.var_types["y"] = orig_x_type
        if "y" in orig_variables:
            self.variables["x"] = orig_y
            self.var_types["x"] = orig_y_type

    # Initially there was more special code for wide-form data where plots were
    # multi-colored by default and then either palette or color could be used.
    # We want to provide backwards compatibility for this behavior in a relatively
    # simple way, so we delete the hue information when color is specified.
    if (
        self.input_format == "wide"
        and "hue" in self.variables
        and color is not None
    ):
        # DataFrame.drop returns a new frame; without the assignment the hue
        # column would silently remain in plot_data.
        self.plot_data = self.plot_data.drop("hue", axis=1)
        self.variables.pop("hue")

    # The concept of an "orientation" is important to the original categorical
    # plots, but there's no provision for it in VectorPlotter, so we need it here.
    # Note that it could be useful for the other functions in at least two ways
    # (orienting a univariate distribution plot from long-form data and selecting
    # the aggregation axis in lineplot), so we may want to eventually refactor it.
    self.orient = infer_orient(
        x=self.plot_data.get("x", None),
        y=self.plot_data.get("y", None),
        orient=orient,
        require_numeric=False,
    )

    self.legend = legend

    # Short-circuit in the case of an empty plot
    if not self.has_xy_data:
        return

    # Categorical plots can be "univariate" in which case they get an anonymous
    # category label on the opposite axis. Note: this duplicates code in the core
    # scale_categorical function. We need to do it here because of the next line.
    if self.orient not in self.variables:
        self.variables[self.orient] = None
        self.var_types[self.orient] = "categorical"
        self.plot_data[self.orient] = ""

    # Categorical variables have discrete levels that we need to track
    cat_levels = categorical_order(self.plot_data[self.orient], order)
    self.var_levels[self.orient] = cat_levels
default_behavior = color is None or palette is not None if force_hue and "hue" not in self.variables and default_behavior: self._redundant_hue = True self.plot_data["hue"] = self.plot_data[self.orient] self.variables["hue"] = self.variables[self.orient] self.var_types["hue"] = "categorical" hue_order = self.var_levels[self.orient] # Because we convert the categorical axis variable to string, # we need to update a dictionary palette too if isinstance(palette, dict): palette = {str(k): v for k, v in palette.items()} else: if "hue" in self.variables: redundant = (self.plot_data["hue"] == self.plot_data[self.orient]).all() else: redundant = False self._redundant_hue = redundant # Previously, categorical plots had a trick where color= could seed the palette. # Because that's an explicit parameterization, we are going to give it one # release cycle with a warning before removing. if "hue" in self.variables and palette is None and color is not None: if not isinstance(color, str): color = mpl.colors.to_hex(color) palette = f"dark:{color}" msg = ( "\n\nSetting a gradient palette using color= is deprecated and will be " f"removed in v0.14.0. Set `palette='{palette}'` for the same effect.\n" ) warnings.warn(msg, FutureWarning, stacklevel=3) return palette, hue_order def _palette_without_hue_backcompat(self, palette, hue_order): """Provide one cycle where palette= implies hue= when not provided""" if "hue" not in self.variables and palette is not None: msg = ( "\n\nPassing `palette` without assigning `hue` is deprecated " f"and will be removed in v0.14.0. 
def _point_kwargs_backcompat(self, scale, join, kwargs):
    """Redirect the deprecated `scale` and `join` arguments into ``kwargs``.

    Each argument that is not the ``deprecated`` sentinel triggers a warning.
    ``kwargs`` is updated in place; nothing is returned.
    """
    if scale is not deprecated:
        # All sizes derive from the rc line width, scaled by 1.8 * scale
        linewidth = mpl.rcParams["lines.linewidth"] * 1.8 * scale
        size_kws = {
            "linewidth": linewidth,
            "markeredgewidth": linewidth * .75,
            "markersize": linewidth * 2,
        }
        scale_msg = (
            "\n\nThe `scale` parameter is deprecated and will be removed in v0.15.0. "
            "You can now control the size of each plot element using matplotlib "
            "`Line2D` parameters (e.g., `linewidth`, `markersize`, etc.).\n"
        )
        warnings.warn(scale_msg, stacklevel=3)
        kwargs.update(size_kws)

    if join is not deprecated:
        parts = [
            "\n\nThe `join` parameter is deprecated and will be removed in v0.15.0."
        ]
        if not join:
            # A falsy `join` maps onto an empty linestyle
            parts.append(
                " You can remove the line between points with `linestyle='none'`."
            )
            kwargs.update(linestyle="")
        parts.append("\n")
        warnings.warn("".join(parts), stacklevel=3)
def _violin_scale_backcompat(self, scale, scale_hue, density_norm, common_norm):
    """Provide two cycles of backcompat for scale kwargs.

    Each deprecated argument that is not the ``deprecated`` sentinel overrides
    its replacement and emits a ``FutureWarning``.

    Returns
    -------
    density_norm, common_norm
        The (possibly overridden) values of the new parameters.
    """
    if scale is not deprecated:
        density_norm = scale
        msg = (
            "\n\nThe `scale` parameter has been renamed and will be removed "
            f"in v0.15.0. Pass `density_norm={scale!r}` for the same effect."
        )
        warnings.warn(msg, FutureWarning, stacklevel=3)

    if scale_hue is not deprecated:
        # Apply the same inversion that the warning below recommends, so the
        # deprecated call and the suggested replacement give the same result.
        common_norm = not scale_hue
        msg = (
            "\n\nThe `scale_hue` parameter has been replaced and will be removed "
            f"in v0.15.0. Pass `common_norm={not scale_hue}` for the same effect."
        )
        warnings.warn(msg, FutureWarning, stacklevel=3)

    return density_norm, common_norm


def _violin_bw_backcompat(self, bw, bw_method):
    """Provide two cycles of backcompat for violin bandwidth parameterization.

    When ``bw`` is not the ``deprecated`` sentinel it replaces ``bw_method``
    and a ``FutureWarning`` is emitted. Returns the resulting ``bw_method``.
    """
    if bw is not deprecated:
        bw_method = bw
        msg = dedent(f"""\n
            The `bw` parameter is deprecated in favor of `bw_method`/`bw_adjust`.
            Setting `bw_method={bw!r}`, but please see docs for the new parameters
            and update your code. This will become an error in seaborn v0.15.0.
        """)
        warnings.warn(msg, FutureWarning, stacklevel=3)
    return bw_method


def _boxen_scale_backcompat(self, scale, width_method):
    """Provide two cycles of backcompat for scale kwargs.

    When ``scale`` is not the ``deprecated`` sentinel it replaces
    ``width_method`` and a ``FutureWarning`` is emitted. Returns the resulting
    ``width_method``.
    """
    if scale is not deprecated:
        width_method = scale
        # Note the closing backtick after the value so the inline code span
        # in the message is terminated.
        msg = (
            "\n\nThe `scale` parameter has been renamed to `width_method` and "
            f"will be removed in v0.15. Pass `width_method={scale!r}`"
        )
        if scale == "area":
            msg += ", but note that the result for 'area' will appear different."
        else:
            msg += " for the same effect."
        warnings.warn(msg, FutureWarning, stacklevel=3)
    return width_method
def _map_prop_with_hue(self, name, value, fallback, plot_kws):
    """Build a mapping from hue level to a property value (pointplot support).

    When ``value`` is the ``default`` sentinel, the value is taken (and
    removed) from ``plot_kws[name]``, falling back to ``fallback``. Without a
    hue variable the mapping has the single key ``None``. With one, a list is
    zipped against the hue levels and any other value is shared by all levels.
    """
    prop = plot_kws.pop(name, fallback) if value is default else value

    if "hue" not in self.variables:
        return {None: prop}

    hue_levels = self._hue_map.levels
    if isinstance(prop, list):
        return dict(zip(hue_levels, prop))
    return dict.fromkeys(hue_levels, prop)
def _dodge_needed(self):
    """Report whether hue levels share a position and would therefore overlap.

    Compares the number of distinct groups formed by the orient/col/row
    variables with the number formed once ``hue`` is added.
    """
    if "hue" not in self.variables:
        return False

    group_cols = list({self.orient, "col", "row"} & set(self.variables))
    n_positions = self.plot_data[group_cols].value_counts().size
    n_with_hue = self.plot_data[[*group_cols, "hue"]].value_counts().size
    return n_positions != n_with_hue


def _dodge(self, keys, data):
    """Shift positions and shrink widths in ``data`` (in place) for one hue level."""
    if "hue" not in self.variables:
        # No hue variable was assigned, so dodging is a no-op.
        # (A warning for hue=None with dodge=True could be added, but silently
        # ignoring the request is considered acceptable.)
        return

    levels = self._hue_map.levels
    position = levels.index(keys["hue"])
    n_levels = len(levels)

    # Each level gets an equal share of the original width
    data["width"] /= n_levels
    share = data["width"]

    # Offset of this level's center relative to the center of the full span
    data[self.orient] += share * position + share / 2 - (share * n_levels) / 2


def _invert_scale(self, ax, data, vars=("x", "y")):
    """Map computed coordinates in ``data`` back through the inverse axis transform."""
    for var in vars:
        _, inverse = _get_transform_functions(ax, var[0])

        # Along the orient axis, turn center/width into an inverted edge/width
        # pair; this must happen before the center column itself is inverted.
        if var == self.orient and "width" in data:
            half = data["width"] / 2
            data["edge"] = inverse(data[var] - half)
            data["width"] = inverse(data[var] + half) - data["edge"].to_numpy()

        for suffix in ("", "min", "max"):
            column = f"{var}{suffix}"
            if column in data:
                data[column] = inverse(data[column])
@property
def _native_width(self):
    """Width of one category slot, in units of the orient variable."""
    # A categorical orient axis always uses unit spacing
    if self.var_types[self.orient] == "categorical":
        return 1

    # Otherwise use the smallest gap between distinct positions, falling back
    # to 1 when there are fewer than two of them
    positions = np.unique(self.comp_data[self.orient])
    if len(positions) <= 1:
        return 1
    return np.nanmin(np.diff(positions))


def _nested_offsets(self, width, dodge):
    """Compute the per-hue-level position offsets, or None without hue levels.

    With ``dodge`` the offsets are evenly spaced and centered on zero so that
    the levels tile ``width``; without it every level gets a zero offset.
    """
    if "hue" not in self.variables:
        return None

    levels = self._hue_map.levels
    if levels is None:
        return None

    n_levels = len(levels)
    if not dodge:
        return np.zeros(n_levels)

    each_width = width / n_levels
    offsets = np.linspace(0, width - each_width, n_levels)
    offsets -= offsets.mean()
    return offsets
def plot_swarms(
    self,
    dodge,
    color,
    warn_thresh,
    plot_kws,
):
    """Draw one scatter collection per group and arrange it as a beeswarm.

    Parameters
    ----------
    dodge
        When truthy, groups are also split by ``hue`` and shifted by the
        offsets from ``_nested_offsets``.
    color
        Passed as ``color`` to ``ax.scatter``.
    warn_thresh
        Forwarded to ``Beeswarm``.
    plot_kws
        Extra keyword arguments for ``ax.scatter`` (also passed to the legend
        configuration). Modified in place: ``"edgecolor"`` is removed when the
        requested marker is not a filled one.

    """
    width = .8 * self._native_width
    offsets = self._nested_offsets(width, dodge)

    iter_vars = [self.orient]
    if dodge:
        iter_vars.append("hue")

    ax = self.ax
    point_collections = {}
    dodge_move = 0

    if "marker" in plot_kws and not MarkerStyle(plot_kws["marker"]).is_filled():
        plot_kws.pop("edgecolor", None)

    for sub_vars, sub_data in self.iter_data(iter_vars,
                                             from_comp_data=True,
                                             allow_empty=True):

        ax = self._get_axes(sub_vars)

        if offsets is not None:
            dodge_move = offsets[sub_data["hue"].map(self._hue_map.levels.index)]

        if not sub_data.empty:
            sub_data[self.orient] = sub_data[self.orient] + dodge_move

        self._invert_scale(ax, sub_data)

        points = ax.scatter(sub_data["x"], sub_data["y"], color=color, **plot_kws)
        if "hue" in self.variables:
            points.set_facecolors(self._hue_map(sub_data["hue"]))

        # Remember each non-empty collection with its axes and swarm center
        if not sub_data.empty:
            point_collections[(ax, sub_data[self.orient].iloc[0])] = points

    beeswarm = Beeswarm(width=width, orient=self.orient, warn_thresh=warn_thresh)
    for (ax, center), points in point_collections.items():
        if points.get_offsets().shape[0] > 1:

            # `center` and `ax` are bound as defaults so that each patched
            # draw method keeps the values from its own loop iteration.
            # A plain closure would read them when the figure is drawn, by
            # which time the loop has finished and `ax` is the last axes.
            def draw(points, renderer, *, center=center, ax=ax):

                beeswarm(points, center)

                if self.orient == "y":
                    scalex = False
                    scaley = ax.get_autoscaley_on()
                else:
                    scalex = ax.get_autoscalex_on()
                    scaley = False

                # This prevents us from undoing the nice categorical axis limits
                # set in _adjust_cat_axis, because that method currently leaves
                # the autoscale flag in its original setting. It may be better
                # to disable autoscaling there to avoid needing to do this.
                fixed_scale = self.var_types[self.orient] == "categorical"
                ax.update_datalim(points.get_datalim(ax.transData))
                if not fixed_scale and (scalex or scaley):
                    ax.autoscale_view(scalex=scalex, scaley=scaley)

                super(points.__class__, points).draw(renderer)

            # Bind as a method of this one collection only
            points.draw = draw.__get__(points)

    _draw_figure(ax.figure)
    self._configure_legend(ax, _scatter_legend_artist, plot_kws)
1 - gap capwidth = plot_kws.get("capwidths", 0.5 * data["width"]) self._invert_scale(ax, data) _, inv = _get_transform_functions(ax, value_var) for stat in ["mean", "med", "q1", "q3", "cilo", "cihi", "whislo", "whishi"]: stats[stat] = inv(stats[stat]) stats["fliers"] = stats["fliers"].map(inv) linear_orient_scale = getattr(ax, f"get_{self.orient}scale")() == "linear" maincolor = self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color if fill: boxprops = { "facecolor": maincolor, "edgecolor": linecolor, **props["box"] } medianprops = {"color": linecolor, **props["median"]} whiskerprops = {"color": linecolor, **props["whisker"]} flierprops = {"markeredgecolor": linecolor, **props["flier"]} capprops = {"color": linecolor, **props["cap"]} else: boxprops = {"color": maincolor, **props["box"]} medianprops = {"color": maincolor, **props["median"]} whiskerprops = {"color": maincolor, **props["whisker"]} flierprops = {"markeredgecolor": maincolor, **props["flier"]} capprops = {"color": maincolor, **props["cap"]} if linewidth is not None: for prop_dict in [boxprops, medianprops, whiskerprops, capprops]: prop_dict.setdefault("linewidth", linewidth) default_kws = dict( bxpstats=stats.to_dict("records"), positions=data[self.orient], # Set width to 0 to avoid going out of domain widths=data["width"] if linear_orient_scale else 0, patch_artist=fill, manage_ticks=False, boxprops=boxprops, medianprops=medianprops, whiskerprops=whiskerprops, flierprops=flierprops, capprops=capprops, # Added in matplotlib 3.10; see below # orientation=orientation **( {"vert": orientation == "vertical"} if _version_predates(mpl, "3.10.0") else {"orientation": orientation} ), # added in matplotlib 3.6.0; see below # capwidths=capwidth, **( {} if _version_predates(mpl, "3.6.0") else {"capwidths": capwidth} ) ) boxplot_kws = {**default_kws, **plot_kws} artists = ax.bxp(**boxplot_kws) # Reset artist widths after adding so everything stays positive ori_idx = ["x", "y"].index(self.orient) if not 
def plot_boxens(
    self,
    width,
    dodge,
    gap,
    fill,
    color,
    linecolor,
    linewidth,
    width_method,
    k_depth,
    outlier_prop,
    trust_alpha,
    showfliers,
    box_kws,
    flier_kws,
    line_kws,
    plot_kws,
):
    """Draw letter-value ("boxen") boxes, a median line, and optional fliers.

    For every subset produced by iterating over the orient variable and
    `hue`, letter values are computed with `LetterValues`, then a
    `PatchCollection` of rectangles, a median line, and (optionally) a
    scatter of outlier points are added to the subset's Axes. The legend is
    configured once at the end.

    Parameters
    ----------
    width : number
        Multiplied by `self._native_width` to give the widest box's extent
        on the orient axis.
    dodge : bool-like
        When truthy, `self._dodge` adjusts the position/width for the subset.
    gap : number
        When truthy, the width is scaled by `1 - gap`.
    fill : bool-like
        When truthy, boxes are filled with shades taken from a
        `light_palette` colormap and edged with `linecolor`; otherwise boxes
        have no face and are edged with the main color.
    color
        Main color used when the subset has no `hue` entry.
    linecolor
        Edge color of filled boxes and color of the median line when `fill`.
    linewidth : number or None
        When None, defaults to `rcParams["lines.linewidth"]`, halved if
        `fill` is truthy.
    width_method : {"exponential", "linear", "area"}
        Rule for the relative box widths; validated with `_check_argument`.
    k_depth, outlier_prop, trust_alpha
        Forwarded unchanged to `LetterValues`.
    showfliers : bool-like
        When truthy, the values in `lv_data["fliers"]` are drawn with
        `ax.scatter`.
    box_kws : dict or None
        Extra keyword arguments for the `PatchCollection`; merged on top of
        `plot_kws`.
    flier_kws : dict or None
        Extra keyword arguments for `ax.scatter` (copied before use).
    line_kws : dict or None
        Extra keyword arguments for the median `ax.plot` call (copied before
        use).
    plot_kws : dict
        Base keyword arguments for the boxes; also used for the legend.
    """
    iter_vars = [self.orient, "hue"]
    # The value variable is whichever axis is not the orient axis.
    value_var = {"x": "y", "y": "x"}[self.orient]

    estimator = LetterValues(k_depth, outlier_prop, trust_alpha)

    width_method_options = ["exponential", "linear", "area"]
    _check_argument("width_method", width_method_options, width_method)

    # box_kws takes precedence over plot_kws; the other dicts are copied so
    # the caller's objects are not shared with the drawing calls below.
    box_kws = plot_kws if box_kws is None else {**plot_kws, **box_kws}
    flier_kws = {} if flier_kws is None else flier_kws.copy()
    line_kws = {} if line_kws is None else line_kws.copy()

    if linewidth is None:
        if fill:
            linewidth = 0.5 * mpl.rcParams["lines.linewidth"]
        else:
            linewidth = mpl.rcParams["lines.linewidth"]

    # Fallback Axes for the legend/autoscale calls after the loop.
    ax = self.ax

    for sub_vars, sub_data in self.iter_data(iter_vars,
                                             from_comp_data=True,
                                             allow_empty=False):

        ax = self._get_axes(sub_vars)
        # Only the second element of each pair is used here; presumably it is
        # the inverse of the axis scale transform -- confirm against
        # `_get_transform_functions`.
        _, inv_ori = _get_transform_functions(ax, self.orient)
        _, inv_val = _get_transform_functions(ax, value_var)

        # Statistics
        lv_data = estimator(sub_data[value_var])
        # Number of rectangles drawn for this subset (see the loop below).
        n = lv_data["k"] * 2 - 1
        vals = lv_data["values"]

        # One-row frame holding the position and full width of this element,
        # in the form that `self._dodge` is called with.
        pos_data = pd.DataFrame({
            self.orient: [sub_vars[self.orient]],
            "width": [width * self._native_width],
        })
        if dodge:
            self._dodge(sub_vars, pos_data)
        if gap:
            pos_data["width"] *= 1 - gap

        # Letter-value boxes
        levels = lv_data["levels"]
        exponent = (levels - 1 - lv_data["k"]).astype(float)
        if width_method == "linear":
            rel_widths = levels + 1
        elif width_method == "exponential":
            rel_widths = 2 ** exponent
        elif width_method == "area":
            tails = levels < (lv_data["k"] - 1)
            rel_widths = 2 ** (exponent - tails) / np.diff(lv_data["values"])

        center = pos_data[self.orient].item()
        # Normalize so the widest box spans the full allotted width.
        widths = rel_widths / rel_widths.max() * pos_data["width"].item()

        # Box geometry is computed on the inverse-transformed coordinates.
        box_vals = inv_val(vals)
        box_pos = inv_ori(center - widths / 2)
        box_heights = inv_val(vals[1:]) - inv_val(vals[:-1])
        box_widths = inv_ori(center + widths / 2) - inv_ori(center - widths / 2)

        maincolor = self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color
        flier_colors = {
            "facecolor": "none", "edgecolor": ".45" if fill else maincolor
        }
        if fill:
            # One shade per box, sampled from a light-to-maincolor colormap.
            cmap = light_palette(maincolor, as_cmap=True)
            boxcolors = cmap(2 ** ((exponent + 2) / 3))
        else:
            boxcolors = maincolor

        boxen = []
        for i in range(n):
            # Swap the roles of position/value depending on orientation.
            if self.orient == "x":
                xy = (box_pos[i], box_vals[i])
                w, h = (box_widths[i], box_heights[i])
            else:
                xy = (box_vals[i], box_pos[i])
                w, h = (box_heights[i], box_widths[i])
            boxen.append(Rectangle(xy, w, h))

        if fill:
            box_colors = {"facecolors": boxcolors, "edgecolors": linecolor}
        else:
            box_colors = {"facecolors": "none", "edgecolors": boxcolors}

        collection_kws = {**box_colors, "linewidth": linewidth, **box_kws}
        # autolim is disabled; the value-axis data limits are updated by hand
        # in the call that follows.
        ax.add_collection(PatchCollection(boxen, **collection_kws), autolim=False)
        ax.update_datalim(
            np.column_stack([box_vals, box_vals]),
            updatex=self.orient == "y",
            updatey=self.orient == "x",
        )

        # Median line
        med = lv_data["median"]
        hw = pos_data["width"].item() / 2
        if self.orient == "x":
            x, y = inv_ori([center - hw, center + hw]), inv_val([med, med])
        else:
            x, y = inv_val([med, med]), inv_ori([center - hw, center + hw])
        default_kws = {
            "color": linecolor if fill else maincolor,
            "solid_capstyle": "butt",
            "linewidth": 1.25 * linewidth,
        }
        # User-supplied line_kws override the defaults.
        ax.plot(x, y, **{**default_kws, **line_kws})

        # Outliers ("fliers")
        if showfliers:
            vals = inv_val(lv_data["fliers"])
            pos = np.full(len(vals), inv_ori(pos_data[self.orient].item()))
            x, y = (pos, vals) if self.orient == "x" else (vals, pos)
            ax.scatter(x, y, **{**flier_colors, "s": 25, **flier_kws})

    # Only the value axis is rescaled.
    ax.autoscale_view(scalex=self.orient == "y", scaley=self.orient == "x")

    legend_artist = _get_patch_legend_artist(fill)
    common_kws = {**box_kws, "linewidth": linewidth, "edgecolor": linecolor}
    self._configure_legend(ax, legend_artist, common_kws)
self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color if not fill: linecolor = maincolor maincolor = "none" default_kws = dict( facecolor=maincolor, edgecolor=linecolor, linewidth=linewidth, ) violin_data.append({ "position": sub_vars[self.orient], "observations": sub_data[value_var], "density": stat_data["density"], "support": stat_data[value_var], "kwargs": {**default_kws, **plot_kws}, "sub_vars": sub_vars, "ax": self._get_axes(sub_vars), }) # Once we've computed all the KDEs, get statistics for normalization def vars_to_key(sub_vars): return tuple((k, v) for k, v in sub_vars.items() if k != self.orient) norm_keys = [vars_to_key(violin["sub_vars"]) for violin in violin_data] if common_norm: common_max_density = np.nanmax([v["density"].max() for v in violin_data]) common_max_count = np.nanmax([len(v["observations"]) for v in violin_data]) max_density = {key: common_max_density for key in norm_keys} max_count = {key: common_max_count for key in norm_keys} else: with warnings.catch_warnings(): # Ignore warning when all violins are singular; it's not important warnings.filterwarnings('ignore', "All-NaN (slice|axis) encountered") max_density = { key: np.nanmax([ v["density"].max() for v in violin_data if vars_to_key(v["sub_vars"]) == key ]) for key in norm_keys } max_count = { key: np.nanmax([ len(v["observations"]) for v in violin_data if vars_to_key(v["sub_vars"]) == key ]) for key in norm_keys } real_width = width * self._native_width # Now iterate through the violins again to apply the normalization and plot for violin in violin_data: index = pd.RangeIndex(0, max(len(violin["support"]), 1)) data = pd.DataFrame({ self.orient: violin["position"], value_var: violin["support"], "density": violin["density"], "width": real_width, }, index=index) if dodge: self._dodge(violin["sub_vars"], data) if gap: data["width"] *= 1 - gap # Normalize the density across the distribution(s) and relative to the width norm_key = vars_to_key(violin["sub_vars"]) hw = data["width"] 
/ 2 peak_density = violin["density"].max() if np.isnan(peak_density): span = 1 elif density_norm == "area": span = data["density"] / max_density[norm_key] elif density_norm == "count": count = len(violin["observations"]) span = data["density"] / peak_density * (count / max_count[norm_key]) elif density_norm == "width": span = data["density"] / peak_density span = span * hw * (2 if split else 1) # Handle split violins (i.e. asymmetric spans) right_side = ( 0 if "hue" not in self.variables else self._hue_map.levels.index(violin["sub_vars"]["hue"]) % 2 ) if split: offsets = (hw, span - hw) if right_side else (span - hw, hw) else: offsets = span, span ax = violin["ax"] _, invx = _get_transform_functions(ax, "x") _, invy = _get_transform_functions(ax, "y") inv_pos = {"x": invx, "y": invy}[self.orient] inv_val = {"x": invx, "y": invy}[value_var] linecolor = violin["kwargs"]["edgecolor"] # Handle singular datasets (one or more observations with no variance if np.isnan(peak_density): pos = data[self.orient].iloc[0] val = violin["observations"].mean() if self.orient == "x": x, y = [pos - offsets[0], pos + offsets[1]], [val, val] else: x, y = [val, val], [pos - offsets[0], pos + offsets[1]] ax.plot(invx(x), invy(y), color=linecolor, linewidth=linewidth) continue # Plot the main violin body plot_func = {"x": ax.fill_betweenx, "y": ax.fill_between}[self.orient] plot_func( inv_val(data[value_var]), inv_pos(data[self.orient] - offsets[0]), inv_pos(data[self.orient] + offsets[1]), **violin["kwargs"] ) # Adjust the observation data obs = violin["observations"] pos_dict = {self.orient: violin["position"], "width": real_width} if dodge: self._dodge(violin["sub_vars"], pos_dict) if gap: pos_dict["width"] *= (1 - gap) # --- Plot the inner components if inner is None: continue elif inner.startswith("point"): pos = np.array([pos_dict[self.orient]] * len(obs)) if split: pos += (-1 if right_side else 1) * pos_dict["width"] / 2 x, y = (pos, obs) if self.orient == "x" else (obs, pos) kws = 
{ "color": linecolor, "edgecolor": linecolor, "s": (linewidth * 2) ** 2, "zorder": violin["kwargs"].get("zorder", 2) + 1, **inner_kws, } ax.scatter(invx(x), invy(y), **kws) elif inner.startswith("stick"): pos0 = np.interp(obs, data[value_var], data[self.orient] - offsets[0]) pos1 = np.interp(obs, data[value_var], data[self.orient] + offsets[1]) pos_pts = np.stack([inv_pos(pos0), inv_pos(pos1)]) val_pts = np.stack([inv_val(obs), inv_val(obs)]) segments = np.stack([pos_pts, val_pts]).transpose(2, 1, 0) if self.orient == "y": segments = segments[:, :, ::-1] kws = { "color": linecolor, "linewidth": linewidth / 2, **inner_kws, } lines = mpl.collections.LineCollection(segments, **kws) ax.add_collection(lines, autolim=False) elif inner.startswith("quart"): stats = np.percentile(obs, [25, 50, 75]) pos0 = np.interp(stats, data[value_var], data[self.orient] - offsets[0]) pos1 = np.interp(stats, data[value_var], data[self.orient] + offsets[1]) pos_pts = np.stack([inv_pos(pos0), inv_pos(pos1)]) val_pts = np.stack([inv_val(stats), inv_val(stats)]) segments = np.stack([pos_pts, val_pts]).transpose(2, 0, 1) if self.orient == "y": segments = segments[:, ::-1, :] dashes = [(1.25, .75), (2.5, 1), (1.25, .75)] for i, segment in enumerate(segments): kws = { "color": linecolor, "linewidth": linewidth, "dashes": dashes[i], **inner_kws, } ax.plot(*segment, **kws) elif inner.startswith("box"): stats = mpl.cbook.boxplot_stats(obs)[0] pos = np.array(pos_dict[self.orient]) if split: pos += (-1 if right_side else 1) * pos_dict["width"] / 2 pos = [pos, pos], [pos, pos], [pos] val = ( [stats["whislo"], stats["whishi"]], [stats["q1"], stats["q3"]], [stats["med"]] ) if self.orient == "x": (x0, x1, x2), (y0, y1, y2) = pos, val else: (x0, x1, x2), (y0, y1, y2) = val, pos if split: offset = (1 if right_side else -1) * box_width / 72 / 2 dx, dy = (offset, 0) if self.orient == "x" else (0, -offset) trans = ax.transData + mpl.transforms.ScaledTranslation( dx, dy, ax.figure.dpi_scale_trans, ) else: 
def plot_points(
    self,
    aggregator,
    markers,
    linestyles,
    dodge,
    color,
    capsize,
    err_kws,
    plot_kws,
):
    """Draw aggregated point estimates joined by a line for each hue level.

    Parameters
    ----------
    aggregator : callable
        Applied to every orient group through `groupby(...).apply`. Its
        `error_method` attribute decides whether error bars are drawn.
    markers, linestyles
        Resolved per hue level with `self._map_prop_with_hue`, defaulting to
        `"o"` and `"-"` respectively.
    dodge : bool or number
        Total span, in units of `self._native_width`, across which the hue
        levels are spread on the orient axis. `True` means `.025` per hue
        level. Dodging is skipped when there are fewer than two hue levels.
    color
        Line color used when the subset has no `hue` entry.
    capsize
        Forwarded to `self.plot_errorbars`.
    err_kws : dict
        Keyword arguments for the error bars. A copy is made per subset and
        missing `color`, `linewidth`, `alpha`, and `zorder` entries are taken
        from the line that was just drawn.
    plot_kws : dict
        Keyword arguments for `ax.plot`, normalized against
        `matplotlib.lines.Line2D` aliases before defaults are applied.
    """
    agg_var = {"x": "y", "y": "x"}[self.orient]
    iter_vars = ["hue"]

    plot_kws = normalize_kwargs(plot_kws, mpl.lines.Line2D)
    plot_kws.setdefault("linewidth", mpl.rcParams["lines.linewidth"] * 1.8)
    plot_kws.setdefault("markeredgewidth", plot_kws["linewidth"] * 0.75)
    plot_kws.setdefault("markersize", plot_kws["linewidth"] * np.sqrt(2 * np.pi))

    markers = self._map_prop_with_hue("marker", markers, "o", plot_kws)
    linestyles = self._map_prop_with_hue("linestyle", linestyles, "-", plot_kws)

    base_positions = self.var_levels[self.orient]
    if self.var_types[self.orient] == "categorical":
        # Use every integer position in the observed range so that the
        # reindex below leaves gaps for positions without data.
        min_cat_val = int(self.comp_data[self.orient].min())
        max_cat_val = int(self.comp_data[self.orient].max())
        base_positions = list(range(min_cat_val, max_cat_val + 1))

    n_hue_levels = 0 if self._hue_map.levels is None else len(self._hue_map.levels)
    if dodge is True:
        dodge = .025 * n_hue_levels

    ax = self.ax

    # Bound up front so the legend call below always has something to use,
    # even if the iteration yields no subsets.
    sub_kws = plot_kws.copy()

    for sub_vars, sub_data in self.iter_data(iter_vars,
                                             from_comp_data=True,
                                             allow_empty=True):

        ax = self._get_axes(sub_vars)

        ori_axis = getattr(ax, f"{self.orient}axis")
        transform, _ = _get_transform_functions(ax, self.orient)
        positions = transform(ori_axis.convert_units(base_positions))
        agg_data = sub_data if sub_data.empty else (
            sub_data
            .groupby(self.orient)
            .apply(aggregator, agg_var, **groupby_apply_include_groups(False))
            .reindex(pd.Index(positions, name=self.orient))
            .reset_index()
        )

        # Spread the hue levels evenly over a span of `dodge`. With fewer
        # than two levels there is nothing to separate, and the step size
        # would otherwise divide by zero (one level) or look up a hue key
        # that does not exist (no hue mapping).
        if dodge and n_hue_levels > 1:
            hue_idx = self._hue_map.levels.index(sub_vars["hue"])
            step_size = dodge / (n_hue_levels - 1)
            offset = -dodge / 2 + step_size * hue_idx
            agg_data[self.orient] += offset * self._native_width

        self._invert_scale(ax, agg_data)

        sub_kws = plot_kws.copy()
        sub_kws.update(
            marker=markers[sub_vars.get("hue")],
            linestyle=linestyles[sub_vars.get("hue")],
            color=self._hue_map(sub_vars["hue"]) if "hue" in sub_vars else color,
        )

        line, = ax.plot(agg_data["x"], agg_data["y"], **sub_kws)

        # Error bars inherit the look of the line unless overridden.
        sub_err_kws = err_kws.copy()
        line_props = line.properties()
        for prop in ["color", "linewidth", "alpha", "zorder"]:
            sub_err_kws.setdefault(prop, line_props[prop])
        if aggregator.error_method is not None:
            self.plot_errorbars(ax, agg_data, capsize, sub_err_kws)

    legend_artist = partial(mpl.lines.Line2D, [], [])
    semantic_kws = {"hue": {"marker": markers, "linestyle": linestyles}}
    self._configure_legend(ax, legend_artist, sub_kws, semantic_kws)
def plot_errorbars(self, ax, data, capsize, err_kws):
    """Draw one error bar (optionally capped) per row of `data` on `ax`.

    Each row must provide the orient coordinate plus `"<var>min"` and
    `"<var>max"` entries, where `<var>` is the non-orient axis. When
    `capsize` is truthy, caps of total width `capsize * self._native_width`
    are added at both ends; bar and caps are drawn in a single `ax.plot`
    call, separated by NaN breaks.
    """
    orient = self.orient
    value_var = {"x": "y", "y": "x"}[orient]
    lo_key, hi_key = f"{value_var}min", f"{value_var}max"

    for record in data.to_dict("records"):

        center = record[orient]
        coords = np.array([center, center])
        extent = np.array([record[lo_key], record[hi_key]])

        if capsize:
            # Cap ends are offset in transformed space, then mapped back.
            half_cap = capsize * self._native_width / 2
            forward, inverse = _get_transform_functions(ax, orient)
            cap_lo = inverse(forward(coords[0]) - half_cap)
            cap_hi = inverse(forward(coords[1]) + half_cap)

            # Layout: lower cap, NaN break, bar, NaN break, upper cap.
            coords = np.concatenate([
                [cap_lo, cap_hi, np.nan], coords, [np.nan, cap_lo, cap_hi],
            ])
            extent = np.concatenate([
                [extent[0], extent[0], np.nan],
                extent,
                [np.nan, extent[-1], extent[-1]],
            ])

        xy = (coords, extent) if orient == "x" else (extent, coords)
        ax.plot(*xy, **err_kws)
"""), # Shared function parameters input_params=dedent("""\ x, y, hue : names of variables in `data` or vector data Inputs for plotting long-form data. See examples for interpretation.\ """), categorical_data=dedent("""\ data : DataFrame, Series, dict, array, or list of arrays Dataset for plotting. If `x` and `y` are absent, this is interpreted as wide-form. Otherwise it is expected to be long-form.\ """), order_vars=dedent("""\ order, hue_order : lists of strings Order to plot the categorical levels in; otherwise the levels are inferred from the data objects.\ """), stat_api_params=dedent("""\ estimator : string or callable that maps vector -> scalar Statistical function to estimate within each categorical bin. errorbar : string, (string, number) tuple, callable or None Name of errorbar method (either "ci", "pi", "se", or "sd"), or a tuple with a method name and a level parameter, or a function that maps from a vector to a (min, max) interval, or None to hide errorbar. See the :doc:`errorbar tutorial ` for more information. .. versionadded:: v0.12.0 n_boot : int Number of bootstrap samples used to compute confidence intervals. seed : int, `numpy.random.Generator`, or `numpy.random.RandomState` Seed or random number generator for reproducible bootstrapping. units : name of variable in `data` or vector data Identifier of sampling units; used by the errorbar function to perform a multilevel bootstrap and account for repeated measures weights : name of variable in `data` or vector data Data values or column used to compute weighted statistics. Note that the use of weights may limit other statistical options. .. versionadded:: v0.13.1\ """), ci=dedent("""\ ci : float Level of the confidence interval to show, in [0, 100]. .. deprecated:: v0.12.0 Use `errorbar=("ci", ...)`.\ """), orient=dedent("""\ orient : "v" | "h" | "x" | "y" Orientation of the plot (vertical or horizontal). 
This is usually inferred based on the type of the input variables, but it can be used to resolve ambiguity when both `x` and `y` are numeric or when plotting wide-form data. .. versionchanged:: v0.13.0 Added 'x'/'y' as options, equivalent to 'v'/'h'.\ """), color=dedent("""\ color : matplotlib color Single color for the elements in the plot.\ """), palette=dedent("""\ palette : palette name, list, dict, or :class:`matplotlib.colors.Colormap` Color palette that maps the hue variable. If the palette is a dictionary, keys should be names of levels and values should be matplotlib colors. The type/value will sometimes force a qualitative/quantitative mapping.\ """), hue_norm=dedent("""\ hue_norm : tuple or :class:`matplotlib.colors.Normalize` object Normalization in data units for colormap applied to the `hue` variable when it is numeric. Not relevant if `hue` is categorical. .. versionadded:: v0.12.0\ """), saturation=dedent("""\ saturation : float Proportion of the original saturation to draw fill colors in. Large patches often look better with desaturated colors, but set this to `1` if you want the colors to perfectly match the input values.\ """), capsize=dedent("""\ capsize : float Width of the "caps" on error bars, relative to bar spacing.\ """), errcolor=dedent("""\ errcolor : matplotlib color Color used for the error bar lines. .. deprecated:: 0.13.0 Use `err_kws={'color': ...}`.\ """), errwidth=dedent("""\ errwidth : float Thickness of error bar lines (and caps), in points. .. deprecated:: 0.13.0 Use `err_kws={'linewidth': ...}`.\ """), fill=dedent("""\ fill : bool If True, use a solid patch. Otherwise, draw as line art. .. versionadded:: v0.13.0\ """), gap=dedent("""\ gap : float Shrink on the orient axis by this factor to add a gap between dodged elements. .. versionadded:: 0.13.0\ """), width=dedent("""\ width : float Width allotted to each element on the orient axis. 
When `native_scale=True`, it is relative to the minimum distance between two values in the native scale.\ """), dodge=dedent("""\ dodge : "auto" or bool When hue mapping is used, whether elements should be narrowed and shifted along the orient axis to eliminate overlap. If `"auto"`, set to `True` when the orient variable is crossed with the categorical variable or `False` otherwise. .. versionchanged:: 0.13.0 Added `"auto"` mode as a new default.\ """), linewidth=dedent("""\ linewidth : float Width of the lines that frame the plot elements.\ """), linecolor=dedent("""\ linecolor : color Color to use for line elements, when `fill` is True. .. versionadded:: v0.13.0\ """), log_scale=dedent("""\ log_scale : bool or number, or pair of bools or numbers Set axis scale(s) to log. A single value sets the data axis for any numeric axes in the plot. A pair of values sets each axis independently. Numeric values are interpreted as the desired base (default 10). When `None` or `False`, seaborn defers to the existing Axes scale. .. versionadded:: v0.13.0\ """), native_scale=dedent("""\ native_scale : bool When True, numeric or datetime values on the categorical axis will maintain their original scaling rather than being converted to fixed indices. .. versionadded:: v0.13.0\ """), formatter=dedent("""\ formatter : callable Function for converting categorical data into strings. Affects both grouping and tick labels. .. versionadded:: v0.13.0\ """), legend=dedent("""\ legend : "auto", "brief", "full", or False How to draw the legend. If "brief", numeric `hue` and `size` variables will be represented with a sample of evenly spaced values. If "full", every group will get an entry in the legend. If "auto", choose between brief or full representation based on number of levels. If `False`, no legend data is added and no legend is drawn. .. versionadded:: v0.13.0\ """), err_kws=dedent("""\ err_kws : dict Parameters of :class:`matplotlib.lines.Line2D`, for the error bar artists. .. 
def boxplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    orient=None, color=None, palette=None, saturation=.75, fill=True,
    dodge="auto", width=.8, gap=0, whis=1.5, linecolor="auto", linewidth=None,
    fliersize=None, hue_norm=None, native_scale=False, log_scale=None, formatter=None,
    legend="auto", ax=None, **kwargs
):

    # Build the plotter from the semantic variable assignments.
    variables = {"x": x, "y": y, "hue": hue}
    p = _CategoricalPlotter(
        data=data,
        variables=variables,
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    ax = plt.gca() if ax is None else ax

    # With no data there is nothing to draw; return the Axes as-is.
    if p.plot_data.empty:
        return ax

    # Must be resolved before scale_categorical changes the coordinate
    # series dtype.
    if dodge == "auto":
        dodge = p._dodge_needed()

    use_categorical_scale = (
        p.var_types.get(p.orient) == "categorical" or not native_scale
    )
    if use_categorical_scale:
        p.scale_categorical(p.orient, order=order, formatter=formatter)

    p._attach(ax, log_scale=log_scale)

    # Backwards-compatibility shims, scheduled for removal in v0.14.0.
    hue_order = p._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = p._hue_backcompat(color, palette, hue_order)

    # Desaturation only applies to filled patches.
    if not fill:
        saturation = 1
    p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)

    # Only color-like keyword arguments take part in default color detection.
    color_keys = ["c", "color", "fc", "facecolor"]
    color_kws = {key: val for key, val in kwargs.items() if key in color_keys}
    color = _default_color(
        ax.fill_between, hue, color, color_kws, saturation=saturation,
    )
    linecolor = p._complement_color(linecolor, color, p._hue_map)

    box_params = dict(
        width=width,
        dodge=dodge,
        gap=gap,
        fill=fill,
        whis=whis,
        color=color,
        linecolor=linecolor,
        linewidth=linewidth,
        fliersize=fliersize,
        plot_kws=kwargs,
    )
    p.plot_boxes(**box_params)

    p._add_axis_labels(ax)
    p._adjust_cat_axis(ax, axis=p.orient)

    return ax
def violinplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    orient=None, color=None, palette=None, saturation=.75, fill=True,
    inner="box", split=False, width=.8, dodge="auto", gap=0,
    linewidth=None, linecolor="auto", cut=2, gridsize=100,
    bw_method="scott", bw_adjust=1, density_norm="area", common_norm=False,
    hue_norm=None, formatter=None, log_scale=None, native_scale=False,
    legend="auto", scale=deprecated, scale_hue=deprecated, bw=deprecated,
    inner_kws=None, ax=None, **kwargs,
):

    # Build the plotter from the semantic variable assignments.
    variables = {"x": x, "y": y, "hue": hue}
    p = _CategoricalPlotter(
        data=data,
        variables=variables,
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    ax = plt.gca() if ax is None else ax

    # With no data there is nothing to draw; return the Axes as-is.
    if p.plot_data.empty:
        return ax

    # Must be resolved before scale_categorical changes the coordinate
    # series dtype.
    if dodge == "auto":
        dodge = p._dodge_needed()

    use_categorical_scale = (
        p.var_types.get(p.orient) == "categorical" or not native_scale
    )
    if use_categorical_scale:
        p.scale_categorical(p.orient, order=order, formatter=formatter)

    p._attach(ax, log_scale=log_scale)

    # Backwards-compatibility shims, scheduled for removal in v0.14.0.
    hue_order = p._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = p._hue_backcompat(color, palette, hue_order)

    # Desaturation only applies to filled patches.
    if not fill:
        saturation = 1
    p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)

    # Only color-like keyword arguments take part in default color detection.
    color_keys = ["c", "color", "fc", "facecolor"]
    color_kws = {key: val for key, val in kwargs.items() if key in color_keys}
    color = _default_color(
        ax.fill_between, hue, color, color_kws, saturation=saturation,
    )
    linecolor = p._complement_color(linecolor, color, p._hue_map)

    # Translate the deprecated scale / scale_hue / bw arguments.
    density_norm, common_norm = p._violin_scale_backcompat(
        scale, scale_hue, density_norm, common_norm,
    )
    bw_method = p._violin_bw_backcompat(bw, bw_method)

    kde_kws = {
        "cut": cut,
        "gridsize": gridsize,
        "bw_method": bw_method,
        "bw_adjust": bw_adjust,
    }
    # Work on a copy so the caller's dict is never modified downstream.
    inner_kws = inner_kws.copy() if inner_kws is not None else {}

    violin_params = dict(
        width=width,
        dodge=dodge,
        gap=gap,
        split=split,
        color=color,
        fill=fill,
        linecolor=linecolor,
        linewidth=linewidth,
        inner=inner,
        density_norm=density_norm,
        common_norm=common_norm,
        kde_kws=kde_kws,
        inner_kws=inner_kws,
        plot_kws=kwargs,
    )
    p.plot_violins(**violin_params)

    p._add_axis_labels(ax)
    p._adjust_cat_axis(ax, axis=p.orient)

    return ax
versionchanged:: v0.13.0 Previously, this option required a `hue` variable with exactly two levels. {width} {dodge} {gap} {linewidth} {linecolor} cut : float Distance, in units of bandwidth, to extend the density past extreme datapoints. Set to 0 to limit the violin within the data range. gridsize : int Number of points in the discrete grid used to evaluate the KDE. bw_method : {{"scott", "silverman", float}} Either the name of a reference rule or the scale factor to use when computing the kernel bandwidth. The actual kernel size will be determined by multiplying the scale factor by the standard deviation of the data within each group. .. versionadded:: v0.13.0 bw_adjust : float Factor that scales the bandwidth to use more or less smoothing. .. versionadded:: v0.13.0 density_norm : {{"area", "count", "width"}} Method that normalizes each density to determine the violin's width. If `area`, each violin will have the same area. If `count`, the width will be proportional to the number of observations. If `width`, each violin will have the same width. .. versionadded:: v0.13.0 common_norm : bool When `True`, normalize the density across all violins. .. versionadded:: v0.13.0 {hue_norm} {formatter} {log_scale} {native_scale} {legend} scale : {{"area", "count", "width"}} .. deprecated:: v0.13.0 See `density_norm`. scale_hue : bool .. deprecated:: v0.13.0 See `common_norm`. bw : {{'scott', 'silverman', float}} .. deprecated:: v0.13.0 See `bw_method` and `bw_adjust`. inner_kws : dict of key, value mappings Keyword arguments for the "inner" plot, passed to one of: - :class:`matplotlib.collections.LineCollection` (with `inner="stick"`) - :meth:`matplotlib.axes.Axes.scatter` (with `inner="point"`) - :meth:`matplotlib.axes.Axes.plot` (with `inner="quart"` or `inner="box"`) Additionally, with `inner="box"`, the keywords `box_width`, `whis_width`, and `marker` receive special handling for the components of the "box" plot. .. 
def boxenplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    orient=None, color=None, palette=None, saturation=.75, fill=True,
    dodge="auto", width=.8, gap=0, linewidth=None, linecolor=None,
    width_method="exponential", k_depth="tukey", outlier_prop=0.007, trust_alpha=0.05,
    showfliers=True, hue_norm=None, log_scale=None, native_scale=False, formatter=None,
    legend="auto", scale=deprecated, box_kws=None, flier_kws=None, line_kws=None,
    ax=None, **kwargs,
):

    # Build the plotter from the semantic variable assignments.
    variables = {"x": x, "y": y, "hue": hue}
    p = _CategoricalPlotter(
        data=data,
        variables=variables,
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    ax = plt.gca() if ax is None else ax

    # With no data there is nothing to draw; return the Axes as-is.
    if p.plot_data.empty:
        return ax

    # Must be resolved before scale_categorical changes the coordinate
    # series dtype.
    if dodge == "auto":
        dodge = p._dodge_needed()

    use_categorical_scale = (
        p.var_types.get(p.orient) == "categorical" or not native_scale
    )
    if use_categorical_scale:
        p.scale_categorical(p.orient, order=order, formatter=formatter)

    p._attach(ax, log_scale=log_scale)

    # Backwards-compatibility shims, scheduled for removal in v0.14.0.
    hue_order = p._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = p._hue_backcompat(color, palette, hue_order)

    # Longer-term deprecation: the old `scale` argument maps to width_method.
    width_method = p._boxen_scale_backcompat(scale, width_method)

    # Desaturation only applies to filled patches.
    if not fill:
        saturation = 1
    p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)

    # TODO how to get default color? Unlike the sibling functions, no
    # user keyword arguments are consulted here.
    color_kws = {}
    color = _default_color(
        ax.fill_between, hue, color, color_kws, saturation=saturation,
    )
    linecolor = p._complement_color(linecolor, color, p._hue_map)

    boxen_params = dict(
        width=width,
        dodge=dodge,
        gap=gap,
        fill=fill,
        color=color,
        linecolor=linecolor,
        linewidth=linewidth,
        width_method=width_method,
        k_depth=k_depth,
        outlier_prop=outlier_prop,
        trust_alpha=trust_alpha,
        showfliers=showfliers,
        box_kws=box_kws,
        flier_kws=flier_kws,
        line_kws=line_kws,
        plot_kws=kwargs,
    )
    p.plot_boxens(**boxen_params)

    p._add_axis_labels(ax)
    p._adjust_cat_axis(ax, axis=p.orient)

    return ax
{categorical_narrative} Parameters ---------- {categorical_data} {input_params} {order_vars} {orient} {color} {palette} {saturation} {fill} {dodge} {width} {gap} {linewidth} {linecolor} width_method : {{"exponential", "linear", "area"}} Method to use for the width of the letter value boxes: - `"exponential"`: Represent the corresponding percentile - `"linear"`: Decrease by a constant amount for each box - `"area"`: Represent the density of data points in that box k_depth : {{"tukey", "proportion", "trustworthy", "full"}} or int The number of levels to compute and draw in each tail: - `"tukey"`: Use log2(n) - 3 levels, covering similar range as boxplot whiskers - `"proportion"`: Leave approximately `outlier_prop` fliers - `"trustworthy"`: Extend to level with confidence of at least `trust_alpha` - `"full"`: Use log2(n) + 1 levels and extend to most extreme points outlier_prop : float Proportion of data expected to be outliers; used when `k_depth="proportion"`. trust_alpha : float Confidence threshold for most extreme level; used when `k_depth="trustworthy"`. showfliers : bool If False, suppress the plotting of outliers. {hue_norm} {log_scale} {native_scale} {formatter} {legend} box_kws : dict Keyword arguments for the box artists; passed to :class:`matplotlib.patches.Rectangle`. .. versionadded:: v0.12.0 line_kws : dict Keyword arguments for the line denoting the median; passed to :meth:`matplotlib.axes.Axes.plot`. .. versionadded:: v0.12.0 flier_kws : dict Keyword arguments for the scatter denoting the outlier observations; passed to :meth:`matplotlib.axes.Axes.scatter`. .. versionadded:: v0.12.0 {ax_in} kwargs : key, value mappings Other keyword arguments are passed to :class:`matplotlib.patches.Rectangle`, superseded by those in `box_kws`. 
def stripplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    jitter=True, dodge=False, orient=None, color=None, palette=None,
    size=5, edgecolor=default, linewidth=0,
    hue_norm=None, log_scale=None, native_scale=False, formatter=None, legend="auto",
    ax=None, **kwargs
):
    # Axes-level entry point for jittered categorical scatterplots. The public
    # docstring is attached to `stripplot.__doc__` right after the definition.
    plotter = _CategoricalPlotter(
        data=data,
        variables={"x": x, "y": y, "hue": hue},
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    if ax is None:
        ax = plt.gca()

    # Nothing to draw: return the axes as-is
    if plotter.plot_data.empty:
        return ax

    is_categorical = plotter.var_types.get(plotter.orient) == "categorical"
    if is_categorical or not native_scale:
        plotter.scale_categorical(plotter.orient, order=order, formatter=formatter)

    plotter._attach(ax, log_scale=log_scale)

    # Deprecations to remove in v0.14.0.
    hue_order = plotter._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = plotter._hue_backcompat(color, palette, hue_order)

    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    color = _default_color(ax.scatter, hue, color, kwargs)
    edgecolor = plotter._complement_color(edgecolor, color, plotter._hue_map)

    # Assemble the keyword arguments handed to the scatter call: an explicit
    # `s` takes precedence over `size`, and the value is squared either way
    kwargs.setdefault("zorder", 3)
    marker_size = kwargs.get("s", size)
    kwargs.update(
        s=marker_size ** 2,
        edgecolor=edgecolor,
        linewidth=linewidth,
    )

    plotter.plot_strips(
        jitter=jitter,
        dodge=dodge,
        color=color,
        plot_kws=kwargs,
    )

    # XXX the distribution plots add axis labels inside their plotting methods;
    # doing it here may be preferable. There is also an open issue suggesting
    # that _attach could add default axes labels, which seems smart.
    plotter._add_axis_labels(ax)
    plotter._adjust_cat_axis(ax, axis=plotter.orient)

    return ax
A strip plot can be drawn on its own, but it is also a good complement to a box or violin plot in cases where you want to show all observations along with some representation of the underlying distribution. {categorical_narrative} Parameters ---------- {categorical_data} {input_params} {order_vars} jitter : float, `True`/`1` is special-cased Amount of jitter (only along the categorical axis) to apply. This can be useful when you have many points and they overlap, so that it is easier to see the distribution. You can specify the amount of jitter (half the width of the uniform random variable support), or use `True` for a good default. dodge : bool When a `hue` variable is assigned, setting this to `True` will separate the strips for different hue levels along the categorical axis and narrow the amount of space allotted to each strip. Otherwise, the points for each level will be plotted in the same strip. {orient} {color} {palette} size : float Radius of the markers, in points. edgecolor : matplotlib color, "gray" is special-cased Color of the lines around each point. If you pass `"gray"`, the brightness is determined by the color palette used for the body of the points. Note that `stripplot` has `linewidth=0` by default, so edge colors are only visible with nonzero line width. {linewidth} {hue_norm} {log_scale} {native_scale} {formatter} {legend} {ax_in} kwargs : key, value mappings Other keyword arguments are passed through to :meth:`matplotlib.axes.Axes.scatter`. Returns ------- {ax_out} See Also -------- {swarmplot} {boxplot} {violinplot} {catplot} Examples -------- .. 
def swarmplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    dodge=False, orient=None, color=None, palette=None,
    size=5, edgecolor=None, linewidth=0, hue_norm=None, log_scale=None,
    native_scale=False, formatter=None, legend="auto", warn_thresh=.05,
    ax=None, **kwargs
):
    # Axes-level entry point for non-overlapping categorical scatterplots. The
    # public docstring is attached to `swarmplot.__doc__` after the definition.
    plotter = _CategoricalPlotter(
        data=data,
        variables={"x": x, "y": y, "hue": hue},
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    if ax is None:
        ax = plt.gca()

    # Nothing to draw: return the axes as-is
    if plotter.plot_data.empty:
        return ax

    is_categorical = plotter.var_types.get(plotter.orient) == "categorical"
    if is_categorical or not native_scale:
        plotter.scale_categorical(plotter.orient, order=order, formatter=formatter)

    plotter._attach(ax, log_scale=log_scale)

    # Checked only after the axes have been attached
    if not plotter.has_xy_data:
        return ax

    # Deprecations to remove in v0.14.0.
    hue_order = plotter._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = plotter._hue_backcompat(color, palette, hue_order)

    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    color = _default_color(ax.scatter, hue, color, kwargs)
    edgecolor = plotter._complement_color(edgecolor, color, plotter._hue_map)

    # An explicit `s` keyword takes precedence over `size`; a `linewidth` of
    # None is replaced by a tenth of that size
    kwargs.setdefault("zorder", 3)
    marker_size = kwargs.get("s", size)
    if linewidth is None:
        linewidth = marker_size / 10

    kwargs.update(
        s=marker_size ** 2,
        edgecolor=edgecolor,
        linewidth=linewidth,
    )

    plotter.plot_swarms(
        dodge=dodge,
        color=color,
        warn_thresh=warn_thresh,
        plot_kws=kwargs,
    )

    plotter._add_axis_labels(ax)
    plotter._adjust_cat_axis(ax, axis=plotter.orient)

    return ax
def barplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
    weights=None, orient=None, color=None, palette=None, saturation=.75,
    fill=True, hue_norm=None, width=.8, dodge="auto", gap=0, log_scale=None,
    native_scale=False, formatter=None, legend="auto", capsize=0, err_kws=None,
    ci=deprecated, errcolor=deprecated, errwidth=deprecated, ax=None, **kwargs,
):
    # Axes-level entry point for aggregated bar plots. The public docstring is
    # attached to `barplot.__doc__` right after the definition.
    errorbar = utils._deprecate_ci(errorbar, ci)

    # Be backwards compatible with len passed directly, which
    # does not work in Series.agg (maybe a pandas bug?)
    if estimator is len:
        estimator = "size"

    plotter = _CategoricalAggPlotter(
        data=data,
        variables={"x": x, "y": y, "hue": hue, "units": units, "weight": weights},
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    if ax is None:
        ax = plt.gca()

    # Nothing to draw: return the axes as-is
    if plotter.plot_data.empty:
        return ax

    # Resolve automatic dodging first, because scale_categorical changes the
    # dtype of the coordinate series that the check relies on
    if dodge == "auto":
        dodge = plotter._dodge_needed()

    is_categorical = plotter.var_types.get(plotter.orient) == "categorical"
    if is_categorical or not native_scale:
        plotter.scale_categorical(plotter.orient, order=order, formatter=formatter)

    plotter._attach(ax, log_scale=log_scale)

    # Deprecations to remove in v0.14.0.
    hue_order = plotter._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = plotter._hue_backcompat(color, palette, hue_order)

    # Unfilled bars are drawn at full saturation
    if not fill:
        saturation = 1
    plotter.map_hue(
        palette=palette, order=hue_order, norm=hue_norm, saturation=saturation,
    )
    color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation)

    # Use the weighted aggregator only when a weight variable was assigned
    if "weight" in plotter.plot_data:
        agg_cls = WeightedAggregator
    else:
        agg_cls = EstimateAggregator
    aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

    if err_kws is None:
        err_kws = {}
    else:
        err_kws = normalize_kwargs(err_kws, mpl.lines.Line2D)

    # Deprecations to remove in v0.15.0.
    err_kws, capsize = plotter._err_kws_backcompat(
        err_kws, errcolor, errwidth, capsize,
    )

    plotter.plot_bars(
        aggregator=aggregator,
        dodge=dodge,
        width=width,
        gap=gap,
        color=color,
        fill=fill,
        capsize=capsize,
        err_kws=err_kws,
        plot_kws=kwargs,
    )

    plotter._add_axis_labels(ax)
    plotter._adjust_cat_axis(ax, axis=plotter.orient)

    return ax
def pointplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
    weights=None, color=None, palette=None, hue_norm=None, markers=default,
    linestyles=default, dodge=False, log_scale=None, native_scale=False,
    orient=None, capsize=0, formatter=None, legend="auto", err_kws=None,
    ci=deprecated, errwidth=deprecated, join=deprecated, scale=deprecated,
    ax=None, **kwargs,
):
    # Axes-level entry point for point-estimate plots. The public docstring is
    # attached to `pointplot.__doc__` right after the definition.
    errorbar = utils._deprecate_ci(errorbar, ci)

    # Handle special backwards compatibility where pointplot originally
    # did *not* default to multi-colored unless a palette was specified.
    # Only the plotter sees this substitute; `color` itself is left alone.
    if color is None and palette is None:
        plotter_color = "C0"
    else:
        plotter_color = color

    plotter = _CategoricalAggPlotter(
        data=data,
        variables={"x": x, "y": y, "hue": hue, "units": units, "weight": weights},
        order=order,
        orient=orient,
        color=plotter_color,
        legend=legend,
    )

    if ax is None:
        ax = plt.gca()

    # Nothing to draw: return the axes as-is
    if plotter.plot_data.empty:
        return ax

    is_categorical = plotter.var_types.get(plotter.orient) == "categorical"
    if is_categorical or not native_scale:
        plotter.scale_categorical(plotter.orient, order=order, formatter=formatter)

    plotter._attach(ax, log_scale=log_scale)

    # Deprecations to remove in v0.14.0.
    hue_order = plotter._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = plotter._hue_backcompat(color, palette, hue_order)

    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    color = _default_color(ax.plot, hue, color, kwargs)

    # Use the weighted aggregator only when a weight variable was assigned
    if "weight" in plotter.plot_data:
        agg_cls = WeightedAggregator
    else:
        agg_cls = EstimateAggregator
    aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

    if err_kws is None:
        err_kws = {}
    else:
        err_kws = normalize_kwargs(err_kws, mpl.lines.Line2D)

    # Deprecations to remove in v0.15.0.
    plotter._point_kwargs_backcompat(scale, join, kwargs)
    err_kws, capsize = plotter._err_kws_backcompat(err_kws, None, errwidth, capsize)

    plotter.plot_points(
        aggregator=aggregator,
        markers=markers,
        linestyles=linestyles,
        dodge=dodge,
        color=color,
        capsize=capsize,
        err_kws=err_kws,
        plot_kws=kwargs,
    )

    plotter._add_axis_labels(ax)
    plotter._adjust_cat_axis(ax, axis=plotter.orient)

    return ax
The lines that join each point from the same `hue` level allow interactions to be judged by differences in slope, which is easier for the eyes than comparing the heights of several groups of points or bars. {categorical_narrative} Parameters ---------- {categorical_data} {input_params} {order_vars} {stat_api_params} {color} {palette} {hue_norm} markers : string or list of strings Markers to use for each of the `hue` levels. linestyles : string or list of strings Line styles to use for each of the `hue` levels. dodge : bool or float Amount to separate the points for each level of the `hue` variable along the categorical axis. Setting to `True` will apply a small default. {log_scale} {native_scale} {orient} {capsize} {formatter} {legend} {err_kws} {ci} {errwidth} join : bool If `True`, connect point estimates with a line. .. deprecated:: v0.13.0 Set `linestyle="none"` to remove the lines between the points. scale : float Scale factor for the plot elements. .. deprecated:: v0.13.0 Control element sizes with :class:`matplotlib.lines.Line2D` parameters. {ax_in} kwargs : key, value mappings Other parameters are passed through to :class:`matplotlib.lines.Line2D`. .. versionadded:: v0.13.0 Returns ------- {ax_out} See Also -------- {barplot} {catplot} Notes ----- It is important to keep in mind that a point plot shows only the mean (or other estimator) value, but in many cases it may be more informative to show the distribution of values at each level of the categorical variables. In that case, other approaches such as a box or violin plot may be more appropriate. Examples -------- .. 
def countplot(
    data=None, *, x=None, y=None, hue=None, order=None, hue_order=None,
    orient=None, color=None, palette=None, saturation=.75, fill=True, hue_norm=None,
    stat="count", width=.8, dodge="auto", gap=0, log_scale=None, native_scale=False,
    formatter=None, legend="auto", ax=None, **kwargs
):
    # Axes-level entry point for bar plots of observation counts. The public
    # docstring is attached to `countplot.__doc__` right after the definition.

    # Only one coordinate variable may be given; the other axis carries the
    # count, seeded here with a constant 1 (or None for an empty input)
    x_given = x is not None
    y_given = y is not None
    if x_given and y_given:
        raise TypeError("Cannot pass values for both `x` and `y`.")
    if y_given:
        orient = "y"
        x = 1 if list(y) else None
    elif x_given:
        orient = "x"
        y = 1 if list(x) else None

    plotter = _CategoricalAggPlotter(
        data=data,
        variables={"x": x, "y": y, "hue": hue},
        order=order,
        orient=orient,
        color=color,
        legend=legend,
    )

    if ax is None:
        ax = plt.gca()

    # Nothing to draw: return the axes as-is
    if plotter.plot_data.empty:
        return ax

    # Resolve automatic dodging first, because scale_categorical changes the
    # dtype of the coordinate series that the check relies on
    if dodge == "auto":
        dodge = plotter._dodge_needed()

    is_categorical = plotter.var_types.get(plotter.orient) == "categorical"
    if is_categorical or not native_scale:
        plotter.scale_categorical(plotter.orient, order=order, formatter=formatter)

    plotter._attach(ax, log_scale=log_scale)

    # Deprecations to remove in v0.14.0.
    hue_order = plotter._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = plotter._hue_backcompat(color, palette, hue_order)

    # Unfilled bars are drawn at full saturation
    if not fill:
        saturation = 1
    plotter.map_hue(
        palette=palette, order=hue_order, norm=hue_norm, saturation=saturation,
    )
    color = _default_color(ax.bar, hue, color, kwargs, saturation=saturation)

    # The count lives on the axis opposite to the categorical one
    count_axis = {"x": "y", "y": "x"}[plotter.orient]
    if plotter.input_format == "wide":
        plotter.plot_data[count_axis] = 1

    stat_options = ["count", "percent", "probability", "proportion"]
    _check_argument("stat", stat_options, stat)
    plotter.variables[count_axis] = stat
    if stat != "count":
        # Pre-scale each unit so the summed bars total 100 (percent) or 1
        denom = 100 if stat == "percent" else 1
        plotter.plot_data[count_axis] /= len(plotter.plot_data) / denom

    # Bar heights are the per-group sums of the unit values; no error bars
    aggregator = EstimateAggregator("sum", errorbar=None)

    plotter.plot_bars(
        aggregator=aggregator,
        dodge=dodge,
        width=width,
        gap=gap,
        color=color,
        fill=fill,
        capsize=0,
        err_kws={},
        plot_kws=kwargs,
    )

    plotter._add_axis_labels(ax)
    plotter._adjust_cat_axis(ax, axis=plotter.orient)

    return ax
def catplot(
    data=None, *, x=None, y=None, hue=None, row=None, col=None, kind="strip",
    estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None, units=None,
    weights=None, order=None, hue_order=None, row_order=None, col_order=None,
    col_wrap=None, height=5, aspect=1, log_scale=None, native_scale=False,
    formatter=None, orient=None, color=None, palette=None, hue_norm=None,
    legend="auto", legend_out=True, sharex=True, sharey=True,
    margin_titles=False, facet_kws=None, ci=deprecated, **kwargs
):
    # Figure-level entry point: builds a FacetGrid, then dispatches on `kind`
    # to the matching plotter method. Kind-specific options arrive through
    # `kwargs` and are popped in each branch before the remainder is forwarded
    # as `plot_kws`. The public docstring is attached to `catplot.__doc__`
    # right after the definition.

    # Check for attempt to plot onto specific axes and warn
    if "ax" in kwargs:
        msg = ("catplot is a figure-level function and does not accept "
               f"target axes. You may wish to try {kind}plot")
        warnings.warn(msg, UserWarning)
        kwargs.pop("ax")

    desaturated_kinds = ["bar", "count", "box", "violin", "boxen"]
    undodged_kinds = ["strip", "swarm", "point"]

    # Kinds that aggregate observations need the aggregating plotter
    if kind in ["bar", "point", "count"]:
        Plotter = _CategoricalAggPlotter
    else:
        Plotter = _CategoricalPlotter

    # Count plots take a single coordinate variable; the other axis is seeded
    # with a constant that is later summed into the count
    if kind == "count":
        if x is None and y is not None:
            orient = "y"
            x = 1
        elif x is not None and y is None:
            orient = "x"
            y = 1
        elif x is not None and y is not None:
            raise ValueError("Cannot pass values for both `x` and `y`.")

    p = Plotter(
        data=data,
        variables=dict(
            x=x, y=y, hue=hue, row=row, col=col, units=units, weight=weights
        ),
        order=order,
        orient=orient,
        # Handle special backwards compatibility where pointplot originally
        # did *not* default to multi-colored unless a palette was specified.
        color="C0" if kind == "point" and palette is None and color is None else color,
        legend=legend,
    )

    for var in ["row", "col"]:
        # Handle faceting variables that lack name information
        if var in p.variables and p.variables[var] is None:
            p.variables[var] = f"_{var}_"

    # Adapt the plot_data dataframe for use with FacetGrid
    facet_data = p.plot_data.rename(columns=p.variables)
    facet_data = facet_data.loc[:, ~facet_data.columns.duplicated()]

    col_name = p.variables.get("col", None)
    row_name = p.variables.get("row", None)

    if facet_kws is None:
        facet_kws = {}

    g = FacetGrid(
        data=facet_data, row=row_name, col=col_name, col_wrap=col_wrap,
        row_order=row_order, col_order=col_order, sharex=sharex, sharey=sharey,
        legend_out=legend_out, margin_titles=margin_titles,
        height=height, aspect=aspect,
        **facet_kws,
    )

    # Capture this here because scale_categorical is going to insert a (null)
    # x variable even if it is empty. It's not clear whether that needs to
    # happen or if disabling that is the cleaner solution.
    has_xy_data = p.has_xy_data

    if not native_scale or p.var_types[p.orient] == "categorical":
        p.scale_categorical(p.orient, order=order, formatter=formatter)

    p._attach(g, log_scale=log_scale)

    if not has_xy_data:
        return g

    # Deprecations to remove in v0.14.0.
    hue_order = p._palette_without_hue_backcompat(palette, hue_order)
    palette, hue_order = p._hue_backcompat(color, palette, hue_order)

    # Other deprecations
    errorbar = utils._deprecate_ci(errorbar, ci)

    # Filled "area" kinds are desaturated by default; everything else is not
    saturation = kwargs.pop(
        "saturation",
        0.75 if kind in desaturated_kinds and kwargs.get("fill", True) else 1
    )
    p.map_hue(palette=palette, order=hue_order, norm=hue_norm, saturation=saturation)

    # Set a default color
    # Otherwise each artist will be plotted separately and trip the color cycle
    if hue is None:
        color = "C0" if color is None else color
        if saturation < 1:
            color = desaturate(color, saturation)

    if kind in ["strip", "swarm"]:
        kwargs = normalize_kwargs(kwargs, mpl.collections.PathCollection)
        kwargs["edgecolor"] = p._complement_color(
            kwargs.pop("edgecolor", default), color, p._hue_map
        )

    width = kwargs.pop("width", 0.8)
    dodge = kwargs.pop("dodge", False if kind in undodged_kinds else "auto")
    if dodge == "auto":
        dodge = p._dodge_needed()

    if "weight" in p.plot_data:
        if kind not in ["bar", "point"]:
            msg = f"The `weights` parameter has no effect with kind={kind!r}."
            warnings.warn(msg, stacklevel=2)
        agg_cls = WeightedAggregator
    else:
        agg_cls = EstimateAggregator

    if kind == "strip":

        jitter = kwargs.pop("jitter", True)
        plot_kws = kwargs.copy()
        plot_kws.setdefault("zorder", 3)
        plot_kws.setdefault("linewidth", 0)
        if "s" not in plot_kws:
            plot_kws["s"] = plot_kws.pop("size", 5) ** 2

        p.plot_strips(
            jitter=jitter,
            dodge=dodge,
            color=color,
            plot_kws=plot_kws,
        )

    elif kind == "swarm":

        warn_thresh = kwargs.pop("warn_thresh", .05)
        plot_kws = kwargs.copy()
        plot_kws.setdefault("zorder", 3)
        if "s" not in plot_kws:
            plot_kws["s"] = plot_kws.pop("size", 5) ** 2

        # An explicit `linewidth=None` is replaced by a tenth of the marker size
        if plot_kws.setdefault("linewidth", 0) is None:
            plot_kws["linewidth"] = np.sqrt(plot_kws["s"]) / 10

        p.plot_swarms(
            dodge=dodge,
            color=color,
            warn_thresh=warn_thresh,
            plot_kws=plot_kws,
        )

    elif kind == "box":

        plot_kws = kwargs.copy()
        gap = plot_kws.pop("gap", 0)
        fill = plot_kws.pop("fill", True)
        whis = plot_kws.pop("whis", 1.5)
        linewidth = plot_kws.pop("linewidth", None)
        fliersize = plot_kws.pop("fliersize", 5)
        linecolor = p._complement_color(
            plot_kws.pop("linecolor", "auto"), color, p._hue_map
        )

        p.plot_boxes(
            width=width,
            dodge=dodge,
            gap=gap,
            fill=fill,
            whis=whis,
            color=color,
            linecolor=linecolor,
            linewidth=linewidth,
            fliersize=fliersize,
            plot_kws=plot_kws,
        )

    elif kind == "violin":

        plot_kws = kwargs.copy()
        gap = plot_kws.pop("gap", 0)
        fill = plot_kws.pop("fill", True)
        split = plot_kws.pop("split", False)
        inner = plot_kws.pop("inner", "box")
        density_norm = plot_kws.pop("density_norm", "area")
        common_norm = plot_kws.pop("common_norm", False)

        scale = plot_kws.pop("scale", deprecated)
        scale_hue = plot_kws.pop("scale_hue", deprecated)
        density_norm, common_norm = p._violin_scale_backcompat(
            scale, scale_hue, density_norm, common_norm,
        )

        bw_method = p._violin_bw_backcompat(
            plot_kws.pop("bw", deprecated), plot_kws.pop("bw_method", "scott")
        )
        kde_kws = dict(
            cut=plot_kws.pop("cut", 2),
            gridsize=plot_kws.pop("gridsize", 100),
            bw_adjust=plot_kws.pop("bw_adjust", 1),
            bw_method=bw_method,
        )

        inner_kws = plot_kws.pop("inner_kws", {}).copy()
        linewidth = plot_kws.pop("linewidth", None)
        linecolor = plot_kws.pop("linecolor", "auto")
        linecolor = p._complement_color(linecolor, color, p._hue_map)

        p.plot_violins(
            width=width,
            dodge=dodge,
            gap=gap,
            split=split,
            color=color,
            fill=fill,
            linecolor=linecolor,
            linewidth=linewidth,
            inner=inner,
            density_norm=density_norm,
            common_norm=common_norm,
            kde_kws=kde_kws,
            inner_kws=inner_kws,
            plot_kws=plot_kws,
        )

    elif kind == "boxen":

        plot_kws = kwargs.copy()
        gap = plot_kws.pop("gap", 0)
        fill = plot_kws.pop("fill", True)
        linecolor = plot_kws.pop("linecolor", "auto")
        linewidth = plot_kws.pop("linewidth", None)
        k_depth = plot_kws.pop("k_depth", "tukey")
        width_method = plot_kws.pop("width_method", "exponential")
        outlier_prop = plot_kws.pop("outlier_prop", 0.007)
        trust_alpha = plot_kws.pop("trust_alpha", 0.05)
        showfliers = plot_kws.pop("showfliers", True)
        box_kws = plot_kws.pop("box_kws", {})
        flier_kws = plot_kws.pop("flier_kws", {})
        line_kws = plot_kws.pop("line_kws", {})

        # Pop the deprecated `scale` keyword (as the violin branch does) so it
        # is consumed here rather than forwarded with the remaining keyword
        # arguments in `plot_kws`. With the `deprecated` sentinel the call
        # mirrors how `boxenplot` invokes the same backcompat helper.
        width_method = p._boxen_scale_backcompat(
            plot_kws.pop("scale", deprecated), width_method
        )
        linecolor = p._complement_color(linecolor, color, p._hue_map)

        p.plot_boxens(
            width=width,
            dodge=dodge,
            gap=gap,
            fill=fill,
            color=color,
            linecolor=linecolor,
            linewidth=linewidth,
            width_method=width_method,
            k_depth=k_depth,
            outlier_prop=outlier_prop,
            trust_alpha=trust_alpha,
            showfliers=showfliers,
            box_kws=box_kws,
            flier_kws=flier_kws,
            line_kws=line_kws,
            plot_kws=plot_kws,
        )

    elif kind == "point":

        aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

        markers = kwargs.pop("markers", default)
        linestyles = kwargs.pop("linestyles", default)

        # Deprecations to remove in v0.15.0.
        # TODO Uncomment when removing deprecation backcompat
        # capsize = kwargs.pop("capsize", 0)
        # err_kws = normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D)
        p._point_kwargs_backcompat(
            kwargs.pop("scale", deprecated),
            kwargs.pop("join", deprecated),
            kwargs
        )
        err_kws, capsize = p._err_kws_backcompat(
            normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D),
            None,
            errwidth=kwargs.pop("errwidth", deprecated),
            capsize=kwargs.pop("capsize", 0),
        )

        p.plot_points(
            aggregator=aggregator,
            markers=markers,
            linestyles=linestyles,
            dodge=dodge,
            color=color,
            capsize=capsize,
            err_kws=err_kws,
            plot_kws=kwargs,
        )

    elif kind == "bar":

        aggregator = agg_cls(estimator, errorbar, n_boot=n_boot, seed=seed)

        err_kws, capsize = p._err_kws_backcompat(
            normalize_kwargs(kwargs.pop("err_kws", {}), mpl.lines.Line2D),
            errcolor=kwargs.pop("errcolor", deprecated),
            errwidth=kwargs.pop("errwidth", deprecated),
            capsize=kwargs.pop("capsize", 0),
        )
        gap = kwargs.pop("gap", 0)
        fill = kwargs.pop("fill", True)

        p.plot_bars(
            aggregator=aggregator,
            dodge=dodge,
            width=width,
            gap=gap,
            color=color,
            fill=fill,
            capsize=capsize,
            err_kws=err_kws,
            plot_kws=kwargs,
        )

    elif kind == "count":

        # Bar heights are the per-group sums of the unit values; no error bars
        aggregator = EstimateAggregator("sum", errorbar=None)

        count_axis = {"x": "y", "y": "x"}[p.orient]
        p.plot_data[count_axis] = 1

        stat_options = ["count", "percent", "probability", "proportion"]
        stat = _check_argument("stat", stat_options, kwargs.pop("stat", "count"))
        p.variables[count_axis] = stat
        if stat != "count":
            # Pre-scale each unit so the summed bars total 100 (percent) or 1
            denom = 100 if stat == "percent" else 1
            p.plot_data[count_axis] /= len(p.plot_data) / denom

        gap = kwargs.pop("gap", 0)
        fill = kwargs.pop("fill", True)

        p.plot_bars(
            aggregator=aggregator,
            dodge=dodge,
            width=width,
            gap=gap,
            color=color,
            fill=fill,
            capsize=0,
            err_kws={},
            plot_kws=kwargs,
        )

    else:
        msg = (
            f"Invalid `kind`: {kind!r}. Options are 'strip', 'swarm', "
            "'box', 'boxen', 'violin', 'bar', 'count', and 'point'."
        )
        raise ValueError(msg)

    for ax in g.axes.flat:
        p._adjust_cat_axis(ax, axis=p.orient)

    g.set_axis_labels(p.variables.get("x"), p.variables.get("y"))
    g.set_titles()
    g.tight_layout()

    # Collect legend data on the grid and clear the per-axes legend reference
    for ax in g.axes.flat:
        g._update_legend_data(ax)
        ax.legend_ = None

    if legend == "auto":
        show_legend = not p._redundant_hue and p.input_format != "wide"
    else:
        show_legend = bool(legend)
    if show_legend:
        g.add_legend(title=p.variables.get("hue"), label_order=hue_order)

    if data is not None:
        # Replace the dataframe on the FacetGrid for any subsequent maps
        g.data = data

    return g
Options are: "strip", "swarm", "box", "violin", "boxen", "point", "bar", or "count". {stat_api_params} {order_vars} row_order, col_order : lists of strings Order to organize the rows and/or columns of the grid in; otherwise the orders are inferred from the data objects. {col_wrap} {height} {aspect} {log_scale} {native_scale} {formatter} {orient} {color} {palette} {hue_norm} {legend} {legend_out} {share_xy} {margin_titles} facet_kws : dict Dictionary of other keyword arguments to pass to :class:`FacetGrid`. kwargs : key, value pairings Other keyword arguments are passed through to the underlying plotting function. Returns ------- :class:`FacetGrid` Returns the :class:`FacetGrid` object with the plot on it for further tweaking. Examples -------- .. include:: ../docstrings/catplot.rst """).format(**_categorical_docs) class Beeswarm: """Modifies a scatterplot artist to show a beeswarm plot.""" def __init__(self, orient="x", width=0.8, warn_thresh=.05): self.orient = orient self.width = width self.warn_thresh = warn_thresh def __call__(self, points, center): """Swarm `points`, a PathCollection, around the `center` position.""" # Convert from point size (area) to diameter ax = points.axes dpi = ax.figure.dpi # Get the original positions of the points orig_xy_data = points.get_offsets() # Reset the categorical positions to the center line cat_idx = 1 if self.orient == "y" else 0 orig_xy_data[:, cat_idx] = center # Transform the data coordinates to point coordinates. 
# We'll figure out the swarm positions in the latter # and then convert back to data coordinates and replot orig_x_data, orig_y_data = orig_xy_data.T orig_xy = ax.transData.transform(orig_xy_data) # Order the variables so that x is the categorical axis if self.orient == "y": orig_xy = orig_xy[:, [1, 0]] # Add a column with each point's radius sizes = points.get_sizes() if sizes.size == 1: sizes = np.repeat(sizes, orig_xy.shape[0]) edge = points.get_linewidth().item() radii = (np.sqrt(sizes) + edge) / 2 * (dpi / 72) orig_xy = np.c_[orig_xy, radii] # Sort along the value axis to facilitate the beeswarm sorter = np.argsort(orig_xy[:, 1]) orig_xyr = orig_xy[sorter] # Adjust points along the categorical axis to prevent overlaps new_xyr = np.empty_like(orig_xyr) new_xyr[sorter] = self.beeswarm(orig_xyr) # Transform the point coordinates back to data coordinates if self.orient == "y": new_xy = new_xyr[:, [1, 0]] else: new_xy = new_xyr[:, :2] new_x_data, new_y_data = ax.transData.inverted().transform(new_xy).T # Add gutters t_fwd, t_inv = _get_transform_functions(ax, self.orient) if self.orient == "y": self.add_gutters(new_y_data, center, t_fwd, t_inv) else: self.add_gutters(new_x_data, center, t_fwd, t_inv) # Reposition the points so they do not overlap if self.orient == "y": points.set_offsets(np.c_[orig_x_data, new_y_data]) else: points.set_offsets(np.c_[new_x_data, orig_y_data]) def beeswarm(self, orig_xyr): """Adjust x position of points to avoid overlaps.""" # In this method, `x` is always the categorical axis # Center of the swarm, in point coordinates midline = orig_xyr[0, 0] # Start the swarm with the first point swarm = np.atleast_2d(orig_xyr[0]) # Loop over the remaining points for xyr_i in orig_xyr[1:]: # Find the points in the swarm that could possibly # overlap with the point we are currently placing neighbors = self.could_overlap(xyr_i, swarm) # Find positions that would be valid individually # with respect to each of the swarm neighbors candidates = 
self.position_candidates(xyr_i, neighbors) # Sort candidates by their centrality offsets = np.abs(candidates[:, 0] - midline) candidates = candidates[np.argsort(offsets)] # Find the first candidate that does not overlap any neighbors new_xyr_i = self.first_non_overlapping_candidate(candidates, neighbors) # Place it into the swarm swarm = np.vstack([swarm, new_xyr_i]) return swarm def could_overlap(self, xyr_i, swarm): """Return a list of all swarm points that could overlap with target.""" # Because we work backwards through the swarm and can short-circuit, # the for-loop is faster than vectorization _, y_i, r_i = xyr_i neighbors = [] for xyr_j in reversed(swarm): _, y_j, r_j = xyr_j if (y_i - y_j) < (r_i + r_j): neighbors.append(xyr_j) else: break return np.array(neighbors)[::-1] def position_candidates(self, xyr_i, neighbors): """Return a list of coordinates that might be valid by adjusting x.""" candidates = [xyr_i] x_i, y_i, r_i = xyr_i left_first = True for x_j, y_j, r_j in neighbors: dy = y_i - y_j dx = np.sqrt(max((r_i + r_j) ** 2 - dy ** 2, 0)) * 1.05 cl, cr = (x_j - dx, y_i, r_i), (x_j + dx, y_i, r_i) if left_first: new_candidates = [cl, cr] else: new_candidates = [cr, cl] candidates.extend(new_candidates) left_first = not left_first return np.array(candidates) def first_non_overlapping_candidate(self, candidates, neighbors): """Find the first candidate that does not overlap with the swarm.""" # If we have no neighbors, all candidates are good. 
if len(neighbors) == 0: return candidates[0] neighbors_x = neighbors[:, 0] neighbors_y = neighbors[:, 1] neighbors_r = neighbors[:, 2] for xyr_i in candidates: x_i, y_i, r_i = xyr_i dx = neighbors_x - x_i dy = neighbors_y - y_i sq_distances = np.square(dx) + np.square(dy) sep_needed = np.square(neighbors_r + r_i) # Good candidate does not overlap any of neighbors which means that # squared distance between candidate and any of the neighbors has # to be at least square of the summed radii good_candidate = np.all(sq_distances >= sep_needed) if good_candidate: return xyr_i raise RuntimeError( "No non-overlapping candidates found. This should not happen." ) def add_gutters(self, points, center, trans_fwd, trans_inv): """Stop points from extending beyond their territory.""" half_width = self.width / 2 low_gutter = trans_inv(trans_fwd(center) - half_width) off_low = points < low_gutter if off_low.any(): points[off_low] = low_gutter high_gutter = trans_inv(trans_fwd(center) + half_width) off_high = points > high_gutter if off_high.any(): points[off_high] = high_gutter gutter_prop = (off_high + off_low).sum() / len(points) if gutter_prop > self.warn_thresh: msg = ( "{:.1%} of the points cannot be placed; you may want " "to decrease the size of the markers or use stripplot." 
).format(gutter_prop) warnings.warn(msg, UserWarning) return points BoxPlotArtists = namedtuple("BoxPlotArtists", "box median whiskers caps fliers mean") class BoxPlotContainer: def __init__(self, artist_dict): self.boxes = artist_dict["boxes"] self.medians = artist_dict["medians"] self.whiskers = artist_dict["whiskers"] self.caps = artist_dict["caps"] self.fliers = artist_dict["fliers"] self.means = artist_dict["means"] self._label = None self._children = [ *self.boxes, *self.medians, *self.whiskers, *self.caps, *self.fliers, *self.means, ] def __repr__(self): return f"" def __getitem__(self, idx): pair_slice = slice(2 * idx, 2 * idx + 2) return BoxPlotArtists( self.boxes[idx] if self.boxes else [], self.medians[idx] if self.medians else [], self.whiskers[pair_slice] if self.whiskers else [], self.caps[pair_slice] if self.caps else [], self.fliers[idx] if self.fliers else [], self.means[idx]if self.means else [], ) def __iter__(self): yield from (self[i] for i in range(len(self.boxes))) def get_label(self): return self._label def set_label(self, value): self._label = value def get_children(self): return self._children def remove(self): for child in self._children: child.remove() ================================================ FILE: seaborn/cm.py ================================================ from matplotlib import colors from seaborn._compat import register_colormap _rocket_lut = [ [ 0.01060815, 0.01808215, 0.10018654], [ 0.01428972, 0.02048237, 0.10374486], [ 0.01831941, 0.0229766 , 0.10738511], [ 0.02275049, 0.02554464, 0.11108639], [ 0.02759119, 0.02818316, 0.11483751], [ 0.03285175, 0.03088792, 0.11863035], [ 0.03853466, 0.03365771, 0.12245873], [ 0.04447016, 0.03648425, 0.12631831], [ 0.05032105, 0.03936808, 0.13020508], [ 0.05611171, 0.04224835, 0.13411624], [ 0.0618531 , 0.04504866, 0.13804929], [ 0.06755457, 0.04778179, 0.14200206], [ 0.0732236 , 0.05045047, 0.14597263], [ 0.0788708 , 0.05305461, 0.14995981], [ 0.08450105, 0.05559631, 0.15396203], [ 
0.09011319, 0.05808059, 0.15797687], [ 0.09572396, 0.06050127, 0.16200507], [ 0.10132312, 0.06286782, 0.16604287], [ 0.10692823, 0.06517224, 0.17009175], [ 0.1125315 , 0.06742194, 0.17414848], [ 0.11813947, 0.06961499, 0.17821272], [ 0.12375803, 0.07174938, 0.18228425], [ 0.12938228, 0.07383015, 0.18636053], [ 0.13501631, 0.07585609, 0.19044109], [ 0.14066867, 0.0778224 , 0.19452676], [ 0.14633406, 0.07973393, 0.1986151 ], [ 0.15201338, 0.08159108, 0.20270523], [ 0.15770877, 0.08339312, 0.20679668], [ 0.16342174, 0.0851396 , 0.21088893], [ 0.16915387, 0.08682996, 0.21498104], [ 0.17489524, 0.08848235, 0.2190294 ], [ 0.18065495, 0.09009031, 0.22303512], [ 0.18643324, 0.09165431, 0.22699705], [ 0.19223028, 0.09317479, 0.23091409], [ 0.19804623, 0.09465217, 0.23478512], [ 0.20388117, 0.09608689, 0.23860907], [ 0.20973515, 0.09747934, 0.24238489], [ 0.21560818, 0.09882993, 0.24611154], [ 0.22150014, 0.10013944, 0.2497868 ], [ 0.22741085, 0.10140876, 0.25340813], [ 0.23334047, 0.10263737, 0.25697736], [ 0.23928891, 0.10382562, 0.2604936 ], [ 0.24525608, 0.10497384, 0.26395596], [ 0.25124182, 0.10608236, 0.26736359], [ 0.25724602, 0.10715148, 0.27071569], [ 0.26326851, 0.1081815 , 0.27401148], [ 0.26930915, 0.1091727 , 0.2772502 ], [ 0.27536766, 0.11012568, 0.28043021], [ 0.28144375, 0.11104133, 0.2835489 ], [ 0.2875374 , 0.11191896, 0.28660853], [ 0.29364846, 0.11275876, 0.2896085 ], [ 0.29977678, 0.11356089, 0.29254823], [ 0.30592213, 0.11432553, 0.29542718], [ 0.31208435, 0.11505284, 0.29824485], [ 0.31826327, 0.1157429 , 0.30100076], [ 0.32445869, 0.11639585, 0.30369448], [ 0.33067031, 0.11701189, 0.30632563], [ 0.33689808, 0.11759095, 0.3088938 ], [ 0.34314168, 0.11813362, 0.31139721], [ 0.34940101, 0.11863987, 0.3138355 ], [ 0.355676 , 0.11910909, 0.31620996], [ 0.36196644, 0.1195413 , 0.31852037], [ 0.36827206, 0.11993653, 0.32076656], [ 0.37459292, 0.12029443, 0.32294825], [ 0.38092887, 0.12061482, 0.32506528], [ 0.38727975, 0.12089756, 0.3271175 ], [ 0.39364518, 
0.12114272, 0.32910494], [ 0.40002537, 0.12134964, 0.33102734], [ 0.40642019, 0.12151801, 0.33288464], [ 0.41282936, 0.12164769, 0.33467689], [ 0.41925278, 0.12173833, 0.33640407], [ 0.42569057, 0.12178916, 0.33806605], [ 0.43214263, 0.12179973, 0.33966284], [ 0.43860848, 0.12177004, 0.34119475], [ 0.44508855, 0.12169883, 0.34266151], [ 0.45158266, 0.12158557, 0.34406324], [ 0.45809049, 0.12142996, 0.34540024], [ 0.46461238, 0.12123063, 0.34667231], [ 0.47114798, 0.12098721, 0.34787978], [ 0.47769736, 0.12069864, 0.34902273], [ 0.48426077, 0.12036349, 0.35010104], [ 0.49083761, 0.11998161, 0.35111537], [ 0.49742847, 0.11955087, 0.35206533], [ 0.50403286, 0.11907081, 0.35295152], [ 0.51065109, 0.11853959, 0.35377385], [ 0.51728314, 0.1179558 , 0.35453252], [ 0.52392883, 0.11731817, 0.35522789], [ 0.53058853, 0.11662445, 0.35585982], [ 0.53726173, 0.11587369, 0.35642903], [ 0.54394898, 0.11506307, 0.35693521], [ 0.5506426 , 0.11420757, 0.35737863], [ 0.55734473, 0.11330456, 0.35775059], [ 0.56405586, 0.11235265, 0.35804813], [ 0.57077365, 0.11135597, 0.35827146], [ 0.5774991 , 0.11031233, 0.35841679], [ 0.58422945, 0.10922707, 0.35848469], [ 0.59096382, 0.10810205, 0.35847347], [ 0.59770215, 0.10693774, 0.35838029], [ 0.60444226, 0.10573912, 0.35820487], [ 0.61118304, 0.10450943, 0.35794557], [ 0.61792306, 0.10325288, 0.35760108], [ 0.62466162, 0.10197244, 0.35716891], [ 0.63139686, 0.10067417, 0.35664819], [ 0.63812122, 0.09938212, 0.35603757], [ 0.64483795, 0.0980891 , 0.35533555], [ 0.65154562, 0.09680192, 0.35454107], [ 0.65824241, 0.09552918, 0.3536529 ], [ 0.66492652, 0.09428017, 0.3526697 ], [ 0.67159578, 0.09306598, 0.35159077], [ 0.67824099, 0.09192342, 0.3504148 ], [ 0.684863 , 0.09085633, 0.34914061], [ 0.69146268, 0.0898675 , 0.34776864], [ 0.69803757, 0.08897226, 0.3462986 ], [ 0.70457834, 0.0882129 , 0.34473046], [ 0.71108138, 0.08761223, 0.3430635 ], [ 0.7175507 , 0.08716212, 0.34129974], [ 0.72398193, 0.08688725, 0.33943958], [ 0.73035829, 0.0868623 , 
0.33748452], [ 0.73669146, 0.08704683, 0.33543669], [ 0.74297501, 0.08747196, 0.33329799], [ 0.74919318, 0.08820542, 0.33107204], [ 0.75535825, 0.08919792, 0.32876184], [ 0.76145589, 0.09050716, 0.32637117], [ 0.76748424, 0.09213602, 0.32390525], [ 0.77344838, 0.09405684, 0.32136808], [ 0.77932641, 0.09634794, 0.31876642], [ 0.78513609, 0.09892473, 0.31610488], [ 0.79085854, 0.10184672, 0.313391 ], [ 0.7965014 , 0.10506637, 0.31063031], [ 0.80205987, 0.10858333, 0.30783 ], [ 0.80752799, 0.11239964, 0.30499738], [ 0.81291606, 0.11645784, 0.30213802], [ 0.81820481, 0.12080606, 0.29926105], [ 0.82341472, 0.12535343, 0.2963705 ], [ 0.82852822, 0.13014118, 0.29347474], [ 0.83355779, 0.13511035, 0.29057852], [ 0.83850183, 0.14025098, 0.2876878 ], [ 0.84335441, 0.14556683, 0.28480819], [ 0.84813096, 0.15099892, 0.281943 ], [ 0.85281737, 0.15657772, 0.27909826], [ 0.85742602, 0.1622583 , 0.27627462], [ 0.86196552, 0.16801239, 0.27346473], [ 0.86641628, 0.17387796, 0.27070818], [ 0.87079129, 0.17982114, 0.26797378], [ 0.87507281, 0.18587368, 0.26529697], [ 0.87925878, 0.19203259, 0.26268136], [ 0.8833417 , 0.19830556, 0.26014181], [ 0.88731387, 0.20469941, 0.25769539], [ 0.89116859, 0.21121788, 0.2553592 ], [ 0.89490337, 0.21785614, 0.25314362], [ 0.8985026 , 0.22463251, 0.25108745], [ 0.90197527, 0.23152063, 0.24918223], [ 0.90530097, 0.23854541, 0.24748098], [ 0.90848638, 0.24568473, 0.24598324], [ 0.911533 , 0.25292623, 0.24470258], [ 0.9144225 , 0.26028902, 0.24369359], [ 0.91717106, 0.26773821, 0.24294137], [ 0.91978131, 0.27526191, 0.24245973], [ 0.92223947, 0.28287251, 0.24229568], [ 0.92456587, 0.29053388, 0.24242622], [ 0.92676657, 0.29823282, 0.24285536], [ 0.92882964, 0.30598085, 0.24362274], [ 0.93078135, 0.31373977, 0.24468803], [ 0.93262051, 0.3215093 , 0.24606461], [ 0.93435067, 0.32928362, 0.24775328], [ 0.93599076, 0.33703942, 0.24972157], [ 0.93752831, 0.34479177, 0.25199928], [ 0.93899289, 0.35250734, 0.25452808], [ 0.94036561, 0.36020899, 0.25734661], [ 
0.94167588, 0.36786594, 0.2603949 ], [ 0.94291042, 0.37549479, 0.26369821], [ 0.94408513, 0.3830811 , 0.26722004], [ 0.94520419, 0.39062329, 0.27094924], [ 0.94625977, 0.39813168, 0.27489742], [ 0.94727016, 0.4055909 , 0.27902322], [ 0.94823505, 0.41300424, 0.28332283], [ 0.94914549, 0.42038251, 0.28780969], [ 0.95001704, 0.42771398, 0.29244728], [ 0.95085121, 0.43500005, 0.29722817], [ 0.95165009, 0.44224144, 0.30214494], [ 0.9524044 , 0.44944853, 0.3072105 ], [ 0.95312556, 0.45661389, 0.31239776], [ 0.95381595, 0.46373781, 0.31769923], [ 0.95447591, 0.47082238, 0.32310953], [ 0.95510255, 0.47787236, 0.32862553], [ 0.95569679, 0.48489115, 0.33421404], [ 0.95626788, 0.49187351, 0.33985601], [ 0.95681685, 0.49882008, 0.34555431], [ 0.9573439 , 0.50573243, 0.35130912], [ 0.95784842, 0.51261283, 0.35711942], [ 0.95833051, 0.51946267, 0.36298589], [ 0.95879054, 0.52628305, 0.36890904], [ 0.95922872, 0.53307513, 0.3748895 ], [ 0.95964538, 0.53983991, 0.38092784], [ 0.96004345, 0.54657593, 0.3870292 ], [ 0.96042097, 0.55328624, 0.39319057], [ 0.96077819, 0.55997184, 0.39941173], [ 0.9611152 , 0.5666337 , 0.40569343], [ 0.96143273, 0.57327231, 0.41203603], [ 0.96173392, 0.57988594, 0.41844491], [ 0.96201757, 0.58647675, 0.42491751], [ 0.96228344, 0.59304598, 0.43145271], [ 0.96253168, 0.5995944 , 0.43805131], [ 0.96276513, 0.60612062, 0.44471698], [ 0.96298491, 0.6126247 , 0.45145074], [ 0.96318967, 0.61910879, 0.45824902], [ 0.96337949, 0.6255736 , 0.46511271], [ 0.96355923, 0.63201624, 0.47204746], [ 0.96372785, 0.63843852, 0.47905028], [ 0.96388426, 0.64484214, 0.4861196 ], [ 0.96403203, 0.65122535, 0.4932578 ], [ 0.96417332, 0.65758729, 0.50046894], [ 0.9643063 , 0.66393045, 0.5077467 ], [ 0.96443322, 0.67025402, 0.51509334], [ 0.96455845, 0.67655564, 0.52251447], [ 0.96467922, 0.68283846, 0.53000231], [ 0.96479861, 0.68910113, 0.53756026], [ 0.96492035, 0.69534192, 0.5451917 ], [ 0.96504223, 0.7015636 , 0.5528892 ], [ 0.96516917, 0.70776351, 0.5606593 ], [ 
0.96530224, 0.71394212, 0.56849894], [ 0.96544032, 0.72010124, 0.57640375], [ 0.96559206, 0.72623592, 0.58438387], [ 0.96575293, 0.73235058, 0.59242739], [ 0.96592829, 0.73844258, 0.60053991], [ 0.96612013, 0.74451182, 0.60871954], [ 0.96632832, 0.75055966, 0.61696136], [ 0.96656022, 0.75658231, 0.62527295], [ 0.96681185, 0.76258381, 0.63364277], [ 0.96709183, 0.76855969, 0.64207921], [ 0.96739773, 0.77451297, 0.65057302], [ 0.96773482, 0.78044149, 0.65912731], [ 0.96810471, 0.78634563, 0.66773889], [ 0.96850919, 0.79222565, 0.6764046 ], [ 0.96893132, 0.79809112, 0.68512266], [ 0.96935926, 0.80395415, 0.69383201], [ 0.9698028 , 0.80981139, 0.70252255], [ 0.97025511, 0.81566605, 0.71120296], [ 0.97071849, 0.82151775, 0.71987163], [ 0.97120159, 0.82736371, 0.72851999], [ 0.97169389, 0.83320847, 0.73716071], [ 0.97220061, 0.83905052, 0.74578903], [ 0.97272597, 0.84488881, 0.75440141], [ 0.97327085, 0.85072354, 0.76299805], [ 0.97383206, 0.85655639, 0.77158353], [ 0.97441222, 0.86238689, 0.78015619], [ 0.97501782, 0.86821321, 0.78871034], [ 0.97564391, 0.87403763, 0.79725261], [ 0.97628674, 0.87986189, 0.8057883 ], [ 0.97696114, 0.88568129, 0.81430324], [ 0.97765722, 0.89149971, 0.82280948], [ 0.97837585, 0.89731727, 0.83130786], [ 0.97912374, 0.90313207, 0.83979337], [ 0.979891 , 0.90894778, 0.84827858], [ 0.98067764, 0.91476465, 0.85676611], [ 0.98137749, 0.92061729, 0.86536915] ] _mako_lut = [ [ 0.04503935, 0.01482344, 0.02092227], [ 0.04933018, 0.01709292, 0.02535719], [ 0.05356262, 0.01950702, 0.03018802], [ 0.05774337, 0.02205989, 0.03545515], [ 0.06188095, 0.02474764, 0.04115287], [ 0.06598247, 0.0275665 , 0.04691409], [ 0.07005374, 0.03051278, 0.05264306], [ 0.07409947, 0.03358324, 0.05834631], [ 0.07812339, 0.03677446, 0.06403249], [ 0.08212852, 0.0400833 , 0.06970862], [ 0.08611731, 0.04339148, 0.07538208], [ 0.09009161, 0.04664706, 0.08105568], [ 0.09405308, 0.04985685, 0.08673591], [ 0.09800301, 0.05302279, 0.09242646], [ 0.10194255, 0.05614641, 
0.09813162], [ 0.10587261, 0.05922941, 0.103854 ], [ 0.1097942 , 0.06227277, 0.10959847], [ 0.11370826, 0.06527747, 0.11536893], [ 0.11761516, 0.06824548, 0.12116393], [ 0.12151575, 0.07117741, 0.12698763], [ 0.12541095, 0.07407363, 0.1328442 ], [ 0.12930083, 0.07693611, 0.13873064], [ 0.13317849, 0.07976988, 0.14465095], [ 0.13701138, 0.08259683, 0.15060265], [ 0.14079223, 0.08542126, 0.15659379], [ 0.14452486, 0.08824175, 0.16262484], [ 0.14820351, 0.09106304, 0.16869476], [ 0.15183185, 0.09388372, 0.17480366], [ 0.15540398, 0.09670855, 0.18094993], [ 0.15892417, 0.09953561, 0.18713384], [ 0.16238588, 0.10236998, 0.19335329], [ 0.16579435, 0.10520905, 0.19960847], [ 0.16914226, 0.10805832, 0.20589698], [ 0.17243586, 0.11091443, 0.21221911], [ 0.17566717, 0.11378321, 0.21857219], [ 0.17884322, 0.11666074, 0.2249565 ], [ 0.18195582, 0.11955283, 0.23136943], [ 0.18501213, 0.12245547, 0.23781116], [ 0.18800459, 0.12537395, 0.24427914], [ 0.19093944, 0.1283047 , 0.25077369], [ 0.19381092, 0.13125179, 0.25729255], [ 0.19662307, 0.13421303, 0.26383543], [ 0.19937337, 0.13719028, 0.27040111], [ 0.20206187, 0.14018372, 0.27698891], [ 0.20469116, 0.14319196, 0.28359861], [ 0.20725547, 0.14621882, 0.29022775], [ 0.20976258, 0.14925954, 0.29687795], [ 0.21220409, 0.15231929, 0.30354703], [ 0.21458611, 0.15539445, 0.31023563], [ 0.21690827, 0.15848519, 0.31694355], [ 0.21916481, 0.16159489, 0.32366939], [ 0.2213631 , 0.16471913, 0.33041431], [ 0.22349947, 0.1678599 , 0.33717781], [ 0.2255714 , 0.1710185 , 0.34395925], [ 0.22758415, 0.17419169, 0.35075983], [ 0.22953569, 0.17738041, 0.35757941], [ 0.23142077, 0.18058733, 0.3644173 ], [ 0.2332454 , 0.18380872, 0.37127514], [ 0.2350092 , 0.18704459, 0.3781528 ], [ 0.23670785, 0.190297 , 0.38504973], [ 0.23834119, 0.19356547, 0.39196711], [ 0.23991189, 0.19684817, 0.39890581], [ 0.24141903, 0.20014508, 0.4058667 ], [ 0.24286214, 0.20345642, 0.4128484 ], [ 0.24423453, 0.20678459, 0.41985299], [ 0.24554109, 0.21012669, 0.42688124], 
[ 0.2467815 , 0.21348266, 0.43393244], [ 0.24795393, 0.21685249, 0.4410088 ], [ 0.24905614, 0.22023618, 0.448113 ], [ 0.25007383, 0.22365053, 0.45519562], [ 0.25098926, 0.22710664, 0.46223892], [ 0.25179696, 0.23060342, 0.46925447], [ 0.25249346, 0.23414353, 0.47623196], [ 0.25307401, 0.23772973, 0.48316271], [ 0.25353152, 0.24136961, 0.49001976], [ 0.25386167, 0.24506548, 0.49679407], [ 0.25406082, 0.2488164 , 0.50348932], [ 0.25412435, 0.25262843, 0.51007843], [ 0.25404842, 0.25650743, 0.51653282], [ 0.25383134, 0.26044852, 0.52286845], [ 0.2534705 , 0.26446165, 0.52903422], [ 0.25296722, 0.2685428 , 0.53503572], [ 0.2523226 , 0.27269346, 0.54085315], [ 0.25153974, 0.27691629, 0.54645752], [ 0.25062402, 0.28120467, 0.55185939], [ 0.24958205, 0.28556371, 0.55701246], [ 0.24842386, 0.28998148, 0.56194601], [ 0.24715928, 0.29446327, 0.56660884], [ 0.24580099, 0.29899398, 0.57104399], [ 0.24436202, 0.30357852, 0.57519929], [ 0.24285591, 0.30819938, 0.57913247], [ 0.24129828, 0.31286235, 0.58278615], [ 0.23970131, 0.3175495 , 0.5862272 ], [ 0.23807973, 0.32226344, 0.58941872], [ 0.23644557, 0.32699241, 0.59240198], [ 0.2348113 , 0.33173196, 0.59518282], [ 0.23318874, 0.33648036, 0.59775543], [ 0.2315855 , 0.34122763, 0.60016456], [ 0.23001121, 0.34597357, 0.60240251], [ 0.2284748 , 0.35071512, 0.6044784 ], [ 0.22698081, 0.35544612, 0.60642528], [ 0.22553305, 0.36016515, 0.60825252], [ 0.22413977, 0.36487341, 0.60994938], [ 0.22280246, 0.36956728, 0.61154118], [ 0.22152555, 0.37424409, 0.61304472], [ 0.22030752, 0.37890437, 0.61446646], [ 0.2191538 , 0.38354668, 0.61581561], [ 0.21806257, 0.38817169, 0.61709794], [ 0.21703799, 0.39277882, 0.61831922], [ 0.21607792, 0.39736958, 0.61948028], [ 0.21518463, 0.40194196, 0.62059763], [ 0.21435467, 0.40649717, 0.62167507], [ 0.21358663, 0.41103579, 0.62271724], [ 0.21288172, 0.41555771, 0.62373011], [ 0.21223835, 0.42006355, 0.62471794], [ 0.21165312, 0.42455441, 0.62568371], [ 0.21112526, 0.42903064, 0.6266318 ], [ 
0.21065161, 0.43349321, 0.62756504], [ 0.21023306, 0.43794288, 0.62848279], [ 0.20985996, 0.44238227, 0.62938329], [ 0.20951045, 0.44680966, 0.63030696], [ 0.20916709, 0.45122981, 0.63124483], [ 0.20882976, 0.45564335, 0.63219599], [ 0.20849798, 0.46005094, 0.63315928], [ 0.20817199, 0.46445309, 0.63413391], [ 0.20785149, 0.46885041, 0.63511876], [ 0.20753716, 0.47324327, 0.63611321], [ 0.20722876, 0.47763224, 0.63711608], [ 0.20692679, 0.48201774, 0.63812656], [ 0.20663156, 0.48640018, 0.63914367], [ 0.20634336, 0.49078002, 0.64016638], [ 0.20606303, 0.49515755, 0.6411939 ], [ 0.20578999, 0.49953341, 0.64222457], [ 0.20552612, 0.50390766, 0.64325811], [ 0.20527189, 0.50828072, 0.64429331], [ 0.20502868, 0.51265277, 0.64532947], [ 0.20479718, 0.51702417, 0.64636539], [ 0.20457804, 0.52139527, 0.64739979], [ 0.20437304, 0.52576622, 0.64843198], [ 0.20418396, 0.53013715, 0.64946117], [ 0.20401238, 0.53450825, 0.65048638], [ 0.20385896, 0.53887991, 0.65150606], [ 0.20372653, 0.54325208, 0.65251978], [ 0.20361709, 0.5476249 , 0.6535266 ], [ 0.20353258, 0.55199854, 0.65452542], [ 0.20347472, 0.55637318, 0.655515 ], [ 0.20344718, 0.56074869, 0.65649508], [ 0.20345161, 0.56512531, 0.65746419], [ 0.20349089, 0.56950304, 0.65842151], [ 0.20356842, 0.57388184, 0.65936642], [ 0.20368663, 0.57826181, 0.66029768], [ 0.20384884, 0.58264293, 0.6612145 ], [ 0.20405904, 0.58702506, 0.66211645], [ 0.20431921, 0.59140842, 0.66300179], [ 0.20463464, 0.59579264, 0.66387079], [ 0.20500731, 0.60017798, 0.66472159], [ 0.20544449, 0.60456387, 0.66555409], [ 0.20596097, 0.60894927, 0.66636568], [ 0.20654832, 0.61333521, 0.66715744], [ 0.20721003, 0.61772167, 0.66792838], [ 0.20795035, 0.62210845, 0.66867802], [ 0.20877302, 0.62649546, 0.66940555], [ 0.20968223, 0.63088252, 0.6701105 ], [ 0.21068163, 0.63526951, 0.67079211], [ 0.21177544, 0.63965621, 0.67145005], [ 0.21298582, 0.64404072, 0.67208182], [ 0.21430361, 0.64842404, 0.67268861], [ 0.21572716, 0.65280655, 0.67326978], [ 0.21726052, 
0.65718791, 0.6738255 ], [ 0.21890636, 0.66156803, 0.67435491], [ 0.220668 , 0.66594665, 0.67485792], [ 0.22255447, 0.67032297, 0.67533374], [ 0.22458372, 0.67469531, 0.67578061], [ 0.22673713, 0.67906542, 0.67620044], [ 0.22901625, 0.6834332 , 0.67659251], [ 0.23142316, 0.68779836, 0.67695703], [ 0.23395924, 0.69216072, 0.67729378], [ 0.23663857, 0.69651881, 0.67760151], [ 0.23946645, 0.70087194, 0.67788018], [ 0.24242624, 0.70522162, 0.67813088], [ 0.24549008, 0.70957083, 0.67835215], [ 0.24863372, 0.71392166, 0.67854868], [ 0.25187832, 0.71827158, 0.67872193], [ 0.25524083, 0.72261873, 0.67887024], [ 0.25870947, 0.72696469, 0.67898912], [ 0.26229238, 0.73130855, 0.67907645], [ 0.26604085, 0.73564353, 0.67914062], [ 0.26993099, 0.73997282, 0.67917264], [ 0.27397488, 0.74429484, 0.67917096], [ 0.27822463, 0.74860229, 0.67914468], [ 0.28264201, 0.75290034, 0.67907959], [ 0.2873016 , 0.75717817, 0.67899164], [ 0.29215894, 0.76144162, 0.67886578], [ 0.29729823, 0.76567816, 0.67871894], [ 0.30268199, 0.76989232, 0.67853896], [ 0.30835665, 0.77407636, 0.67833512], [ 0.31435139, 0.77822478, 0.67811118], [ 0.3206671 , 0.78233575, 0.67786729], [ 0.32733158, 0.78640315, 0.67761027], [ 0.33437168, 0.79042043, 0.67734882], [ 0.34182112, 0.79437948, 0.67709394], [ 0.34968889, 0.79827511, 0.67685638], [ 0.35799244, 0.80210037, 0.67664969], [ 0.36675371, 0.80584651, 0.67649539], [ 0.3759816 , 0.80950627, 0.67641393], [ 0.38566792, 0.81307432, 0.67642947], [ 0.39579804, 0.81654592, 0.67656899], [ 0.40634556, 0.81991799, 0.67686215], [ 0.41730243, 0.82318339, 0.67735255], [ 0.4285828 , 0.82635051, 0.6780564 ], [ 0.44012728, 0.82942353, 0.67900049], [ 0.45189421, 0.83240398, 0.68021733], [ 0.46378379, 0.83530763, 0.6817062 ], [ 0.47573199, 0.83814472, 0.68347352], [ 0.48769865, 0.84092197, 0.68552698], [ 0.49962354, 0.84365379, 0.68783929], [ 0.5114027 , 0.8463718 , 0.69029789], [ 0.52301693, 0.84908401, 0.69288545], [ 0.53447549, 0.85179048, 0.69561066], [ 0.54578602, 0.8544913 , 
0.69848331], [ 0.55695565, 0.85718723, 0.70150427], [ 0.56798832, 0.85987893, 0.70468261], [ 0.57888639, 0.86256715, 0.70802931], [ 0.5896541 , 0.8652532 , 0.71154204], [ 0.60028928, 0.86793835, 0.71523675], [ 0.61079441, 0.87062438, 0.71910895], [ 0.62116633, 0.87331311, 0.72317003], [ 0.63140509, 0.87600675, 0.72741689], [ 0.64150735, 0.87870746, 0.73185717], [ 0.65147219, 0.8814179 , 0.73648495], [ 0.66129632, 0.8841403 , 0.74130658], [ 0.67097934, 0.88687758, 0.74631123], [ 0.68051833, 0.88963189, 0.75150483], [ 0.68991419, 0.89240612, 0.75687187], [ 0.69916533, 0.89520211, 0.76241714], [ 0.70827373, 0.89802257, 0.76812286], [ 0.71723995, 0.90086891, 0.77399039], [ 0.72606665, 0.90374337, 0.7800041 ], [ 0.73475675, 0.90664718, 0.78615802], [ 0.74331358, 0.90958151, 0.79244474], [ 0.75174143, 0.91254787, 0.79884925], [ 0.76004473, 0.91554656, 0.80536823], [ 0.76827704, 0.91856549, 0.81196513], [ 0.77647029, 0.921603 , 0.81855729], [ 0.78462009, 0.92466151, 0.82514119], [ 0.79273542, 0.92773848, 0.83172131], [ 0.8008109 , 0.93083672, 0.83829355], [ 0.80885107, 0.93395528, 0.84485982], [ 0.81685878, 0.9370938 , 0.85142101], [ 0.82483206, 0.94025378, 0.8579751 ], [ 0.83277661, 0.94343371, 0.86452477], [ 0.84069127, 0.94663473, 0.87106853], [ 0.84857662, 0.9498573 , 0.8776059 ], [ 0.8564431 , 0.95309792, 0.88414253], [ 0.86429066, 0.95635719, 0.89067759], [ 0.87218969, 0.95960708, 0.89725384] ] _vlag_lut = [ [ 0.13850039, 0.41331206, 0.74052025], [ 0.15077609, 0.41762684, 0.73970427], [ 0.16235219, 0.4219191 , 0.7389667 ], [ 0.1733322 , 0.42619024, 0.73832537], [ 0.18382538, 0.43044226, 0.73776764], [ 0.19394034, 0.4346772 , 0.73725867], [ 0.20367115, 0.43889576, 0.73685314], [ 0.21313625, 0.44310003, 0.73648045], [ 0.22231173, 0.44729079, 0.73619681], [ 0.23125148, 0.45146945, 0.73597803], [ 0.23998101, 0.45563715, 0.7358223 ], [ 0.24853358, 0.45979489, 0.73571524], [ 0.25691416, 0.4639437 , 0.73566943], [ 0.26513894, 0.46808455, 0.73568319], [ 0.27322194, 
0.47221835, 0.73575497], [ 0.28117543, 0.47634598, 0.73588332], [ 0.28901021, 0.48046826, 0.73606686], [ 0.2967358 , 0.48458597, 0.73630433], [ 0.30436071, 0.48869986, 0.73659451], [ 0.3118955 , 0.49281055, 0.73693255], [ 0.31935389, 0.49691847, 0.73730851], [ 0.32672701, 0.5010247 , 0.73774013], [ 0.33402607, 0.50512971, 0.73821941], [ 0.34125337, 0.50923419, 0.73874905], [ 0.34840921, 0.51333892, 0.73933402], [ 0.35551826, 0.51744353, 0.73994642], [ 0.3625676 , 0.52154929, 0.74060763], [ 0.36956356, 0.52565656, 0.74131327], [ 0.37649902, 0.52976642, 0.74207698], [ 0.38340273, 0.53387791, 0.74286286], [ 0.39025859, 0.53799253, 0.7436962 ], [ 0.39706821, 0.54211081, 0.744578 ], [ 0.40384046, 0.54623277, 0.74549872], [ 0.41058241, 0.55035849, 0.74645094], [ 0.41728385, 0.55448919, 0.74745174], [ 0.42395178, 0.55862494, 0.74849357], [ 0.4305964 , 0.56276546, 0.74956387], [ 0.4372044 , 0.56691228, 0.75068412], [ 0.4437909 , 0.57106468, 0.75183427], [ 0.45035117, 0.5752235 , 0.75302312], [ 0.45687824, 0.57938983, 0.75426297], [ 0.46339713, 0.58356191, 0.75551816], [ 0.46988778, 0.58774195, 0.75682037], [ 0.47635605, 0.59192986, 0.75816245], [ 0.48281101, 0.5961252 , 0.75953212], [ 0.4892374 , 0.60032986, 0.76095418], [ 0.49566225, 0.60454154, 0.76238852], [ 0.50206137, 0.60876307, 0.76387371], [ 0.50845128, 0.61299312, 0.76538551], [ 0.5148258 , 0.61723272, 0.76693475], [ 0.52118385, 0.62148236, 0.76852436], [ 0.52753571, 0.62574126, 0.77013939], [ 0.53386831, 0.63001125, 0.77180152], [ 0.54020159, 0.63429038, 0.7734803 ], [ 0.54651272, 0.63858165, 0.77521306], [ 0.55282975, 0.64288207, 0.77695608], [ 0.55912585, 0.64719519, 0.77875327], [ 0.56542599, 0.65151828, 0.78056551], [ 0.57170924, 0.65585426, 0.78242747], [ 0.57799572, 0.6602009 , 0.78430751], [ 0.58426817, 0.66456073, 0.78623458], [ 0.590544 , 0.66893178, 0.78818117], [ 0.59680758, 0.67331643, 0.79017369], [ 0.60307553, 0.67771273, 0.79218572], [ 0.60934065, 0.68212194, 0.79422987], [ 0.61559495, 0.68654548, 
0.7963202 ], [ 0.62185554, 0.69098125, 0.79842918], [ 0.62810662, 0.69543176, 0.80058381], [ 0.63436425, 0.69989499, 0.80275812], [ 0.64061445, 0.70437326, 0.80497621], [ 0.6468706 , 0.70886488, 0.80721641], [ 0.65312213, 0.7133717 , 0.80949719], [ 0.65937818, 0.71789261, 0.81180392], [ 0.66563334, 0.72242871, 0.81414642], [ 0.67189155, 0.72697967, 0.81651872], [ 0.67815314, 0.73154569, 0.81892097], [ 0.68441395, 0.73612771, 0.82136094], [ 0.69068321, 0.74072452, 0.82382353], [ 0.69694776, 0.7453385 , 0.82633199], [ 0.70322431, 0.74996721, 0.8288583 ], [ 0.70949595, 0.75461368, 0.83143221], [ 0.7157774 , 0.75927574, 0.83402904], [ 0.72206299, 0.76395461, 0.83665922], [ 0.72835227, 0.76865061, 0.8393242 ], [ 0.73465238, 0.7733628 , 0.84201224], [ 0.74094862, 0.77809393, 0.84474951], [ 0.74725683, 0.78284158, 0.84750915], [ 0.75357103, 0.78760701, 0.85030217], [ 0.75988961, 0.79239077, 0.85313207], [ 0.76621987, 0.79719185, 0.85598668], [ 0.77255045, 0.8020125 , 0.85888658], [ 0.77889241, 0.80685102, 0.86181298], [ 0.78524572, 0.81170768, 0.86476656], [ 0.79159841, 0.81658489, 0.86776906], [ 0.79796459, 0.82148036, 0.8707962 ], [ 0.80434168, 0.82639479, 0.87385315], [ 0.8107221 , 0.83132983, 0.87695392], [ 0.81711301, 0.8362844 , 0.88008641], [ 0.82351479, 0.84125863, 0.88325045], [ 0.82992772, 0.84625263, 0.88644594], [ 0.83634359, 0.85126806, 0.8896878 ], [ 0.84277295, 0.85630293, 0.89295721], [ 0.84921192, 0.86135782, 0.89626076], [ 0.85566206, 0.866432 , 0.89959467], [ 0.86211514, 0.87152627, 0.90297183], [ 0.86857483, 0.87663856, 0.90638248], [ 0.87504231, 0.88176648, 0.90981938], [ 0.88151194, 0.88690782, 0.91328493], [ 0.88797938, 0.89205857, 0.91677544], [ 0.89443865, 0.89721298, 0.9202854 ], [ 0.90088204, 0.90236294, 0.92380601], [ 0.90729768, 0.90749778, 0.92732797], [ 0.91367037, 0.91260329, 0.93083814], [ 0.91998105, 0.91766106, 0.93431861], [ 0.92620596, 0.92264789, 0.93774647], [ 0.93231683, 0.9275351 , 0.94109192], [ 0.93827772, 0.9322888 , 
0.94432312], [ 0.94404755, 0.93686925, 0.94740137], [ 0.94958284, 0.94123072, 0.95027696], [ 0.95482682, 0.9453245 , 0.95291103], [ 0.9597248 , 0.94909728, 0.95525103], [ 0.96422552, 0.95249273, 0.95723271], [ 0.96826161, 0.95545812, 0.95882188], [ 0.97178458, 0.95793984, 0.95995705], [ 0.97474105, 0.95989142, 0.96059997], [ 0.97708604, 0.96127366, 0.96071853], [ 0.97877855, 0.96205832, 0.96030095], [ 0.97978484, 0.96222949, 0.95935496], [ 0.9805997 , 0.96155216, 0.95813083], [ 0.98152619, 0.95993719, 0.95639322], [ 0.9819726 , 0.95766608, 0.95399269], [ 0.98191855, 0.9547873 , 0.95098107], [ 0.98138514, 0.95134771, 0.94740644], [ 0.98040845, 0.94739906, 0.94332125], [ 0.97902107, 0.94300131, 0.93878672], [ 0.97729348, 0.93820409, 0.93385135], [ 0.9752533 , 0.933073 , 0.92858252], [ 0.97297834, 0.92765261, 0.92302309], [ 0.97049104, 0.92200317, 0.91723505], [ 0.96784372, 0.91616744, 0.91126063], [ 0.96507281, 0.91018664, 0.90514124], [ 0.96222034, 0.90409203, 0.89890756], [ 0.9593079 , 0.89791478, 0.89259122], [ 0.95635626, 0.89167908, 0.88621654], [ 0.95338303, 0.88540373, 0.87980238], [ 0.95040174, 0.87910333, 0.87336339], [ 0.94742246, 0.87278899, 0.86691076], [ 0.94445249, 0.86646893, 0.86045277], [ 0.94150476, 0.86014606, 0.85399191], [ 0.93857394, 0.85382798, 0.84753642], [ 0.93566206, 0.84751766, 0.84108935], [ 0.93277194, 0.8412164 , 0.83465197], [ 0.92990106, 0.83492672, 0.82822708], [ 0.92704736, 0.82865028, 0.82181656], [ 0.92422703, 0.82238092, 0.81541333], [ 0.92142581, 0.81612448, 0.80902415], [ 0.91864501, 0.80988032, 0.80264838], [ 0.91587578, 0.80365187, 0.79629001], [ 0.9131367 , 0.79743115, 0.78994 ], [ 0.91041602, 0.79122265, 0.78360361], [ 0.90771071, 0.78502727, 0.77728196], [ 0.90501581, 0.77884674, 0.7709771 ], [ 0.90235365, 0.77267117, 0.76467793], [ 0.8997019 , 0.76650962, 0.75839484], [ 0.89705346, 0.76036481, 0.752131 ], [ 0.89444021, 0.75422253, 0.74587047], [ 0.89183355, 0.74809474, 0.73962689], [ 0.88923216, 0.74198168, 0.73340061], [ 
0.88665892, 0.73587283, 0.72717995], [ 0.88408839, 0.72977904, 0.72097718], [ 0.88153537, 0.72369332, 0.71478461], [ 0.87899389, 0.7176179 , 0.70860487], [ 0.87645157, 0.71155805, 0.7024439 ], [ 0.8739399 , 0.70549893, 0.6962854 ], [ 0.87142626, 0.6994551 , 0.69014561], [ 0.8689268 , 0.69341868, 0.68401597], [ 0.86643562, 0.687392 , 0.67789917], [ 0.86394434, 0.68137863, 0.67179927], [ 0.86147586, 0.67536728, 0.665704 ], [ 0.85899928, 0.66937226, 0.6596292 ], [ 0.85654668, 0.66337773, 0.6535577 ], [ 0.85408818, 0.65739772, 0.64750494], [ 0.85164413, 0.65142189, 0.64145983], [ 0.84920091, 0.6454565 , 0.63542932], [ 0.84676427, 0.63949827, 0.62941 ], [ 0.84433231, 0.63354773, 0.62340261], [ 0.84190106, 0.62760645, 0.61740899], [ 0.83947935, 0.62166951, 0.61142404], [ 0.8370538 , 0.61574332, 0.60545478], [ 0.83463975, 0.60981951, 0.59949247], [ 0.83221877, 0.60390724, 0.593547 ], [ 0.82980985, 0.59799607, 0.58760751], [ 0.82740268, 0.59209095, 0.58167944], [ 0.82498638, 0.5861973 , 0.57576866], [ 0.82258181, 0.5803034 , 0.56986307], [ 0.82016611, 0.57442123, 0.56397539], [ 0.81776305, 0.56853725, 0.55809173], [ 0.81534551, 0.56266602, 0.55222741], [ 0.81294293, 0.55679056, 0.5463651 ], [ 0.81052113, 0.55092973, 0.54052443], [ 0.80811509, 0.54506305, 0.53468464], [ 0.80568952, 0.53921036, 0.52886622], [ 0.80327506, 0.53335335, 0.52305077], [ 0.80084727, 0.52750583, 0.51725256], [ 0.79842217, 0.5216578 , 0.51146173], [ 0.79599382, 0.51581223, 0.50568155], [ 0.79355781, 0.50997127, 0.49991444], [ 0.79112596, 0.50412707, 0.49415289], [ 0.78867442, 0.49829386, 0.48841129], [ 0.7862306 , 0.49245398, 0.48267247], [ 0.7837687 , 0.48662309, 0.47695216], [ 0.78130809, 0.4807883 , 0.47123805], [ 0.77884467, 0.47495151, 0.46553236], [ 0.77636283, 0.46912235, 0.45984473], [ 0.77388383, 0.46328617, 0.45416141], [ 0.77138912, 0.45745466, 0.44849398], [ 0.76888874, 0.45162042, 0.44283573], [ 0.76638802, 0.44577901, 0.43718292], [ 0.76386116, 0.43994762, 0.43155211], [ 0.76133542, 
0.43410655, 0.42592523], [ 0.75880631, 0.42825801, 0.42030488], [ 0.75624913, 0.42241905, 0.41470727], [ 0.7536919 , 0.41656866, 0.40911347], [ 0.75112748, 0.41071104, 0.40352792], [ 0.74854331, 0.40485474, 0.3979589 ], [ 0.74594723, 0.39899309, 0.39240088], [ 0.74334332, 0.39312199, 0.38685075], [ 0.74073277, 0.38723941, 0.3813074 ], [ 0.73809409, 0.38136133, 0.37578553], [ 0.73544692, 0.37547129, 0.37027123], [ 0.73278943, 0.36956954, 0.36476549], [ 0.73011829, 0.36365761, 0.35927038], [ 0.72743485, 0.35773314, 0.35378465], [ 0.72472722, 0.35180504, 0.34831662], [ 0.72200473, 0.34586421, 0.34285937], [ 0.71927052, 0.33990649, 0.33741033], [ 0.71652049, 0.33393396, 0.33197219], [ 0.71375362, 0.32794602, 0.32654545], [ 0.71096951, 0.32194148, 0.32113016], [ 0.70816772, 0.31591904, 0.31572637], [ 0.70534784, 0.30987734, 0.31033414], [ 0.70250944, 0.30381489, 0.30495353], [ 0.69965211, 0.2977301 , 0.2995846 ], [ 0.6967754 , 0.29162126, 0.29422741], [ 0.69388446, 0.28548074, 0.28887769], [ 0.69097561, 0.2793096 , 0.28353795], [ 0.68803513, 0.27311993, 0.27821876], [ 0.6850794 , 0.26689144, 0.27290694], [ 0.682108 , 0.26062114, 0.26760246], [ 0.67911013, 0.2543177 , 0.26231367], [ 0.67609393, 0.24796818, 0.25703372], [ 0.67305921, 0.24156846, 0.25176238], [ 0.67000176, 0.23511902, 0.24650278], [ 0.66693423, 0.22859879, 0.24124404], [ 0.6638441 , 0.22201742, 0.2359961 ], [ 0.66080672, 0.21526712, 0.23069468] ] _icefire_lut = [ [ 0.73936227, 0.90443867, 0.85757238], [ 0.72888063, 0.89639109, 0.85488394], [ 0.71834255, 0.88842162, 0.8521605 ], [ 0.70773866, 0.88052939, 0.849422 ], [ 0.69706215, 0.87271313, 0.84668315], [ 0.68629021, 0.86497329, 0.84398721], [ 0.67543654, 0.85730617, 0.84130969], [ 0.66448539, 0.84971123, 0.83868005], [ 0.65342679, 0.84218728, 0.83611512], [ 0.64231804, 0.83471867, 0.83358584], [ 0.63117745, 0.827294 , 0.83113431], [ 0.62000484, 0.81991069, 0.82876741], [ 0.60879435, 0.81256797, 0.82648905], [ 0.59754118, 0.80526458, 0.82430414], [ 
0.58624247, 0.79799884, 0.82221573], [ 0.57489525, 0.7907688 , 0.82022901], [ 0.56349779, 0.78357215, 0.81834861], [ 0.55204294, 0.77640827, 0.81657563], [ 0.54052516, 0.76927562, 0.81491462], [ 0.52894085, 0.76217215, 0.81336913], [ 0.51728854, 0.75509528, 0.81194156], [ 0.50555676, 0.74804469, 0.81063503], [ 0.49373871, 0.7410187 , 0.80945242], [ 0.48183174, 0.73401449, 0.80839675], [ 0.46982587, 0.72703075, 0.80747097], [ 0.45770893, 0.72006648, 0.80667756], [ 0.44547249, 0.71311941, 0.80601991], [ 0.43318643, 0.70617126, 0.80549278], [ 0.42110294, 0.69916972, 0.80506683], [ 0.40925101, 0.69211059, 0.80473246], [ 0.3976693 , 0.68498786, 0.80448272], [ 0.38632002, 0.67781125, 0.80431024], [ 0.37523981, 0.67057537, 0.80420832], [ 0.36442578, 0.66328229, 0.80417474], [ 0.35385939, 0.65593699, 0.80420591], [ 0.34358916, 0.64853177, 0.8043 ], [ 0.33355526, 0.64107876, 0.80445484], [ 0.32383062, 0.63356578, 0.80467091], [ 0.31434372, 0.62600624, 0.8049475 ], [ 0.30516161, 0.618389 , 0.80528692], [ 0.29623491, 0.61072284, 0.80569021], [ 0.28759072, 0.60300319, 0.80616055], [ 0.27923924, 0.59522877, 0.80669803], [ 0.27114651, 0.5874047 , 0.80730545], [ 0.26337153, 0.57952055, 0.80799113], [ 0.25588696, 0.57157984, 0.80875922], [ 0.248686 , 0.56358255, 0.80961366], [ 0.24180668, 0.55552289, 0.81055123], [ 0.23526251, 0.54739477, 0.8115939 ], [ 0.22921445, 0.53918506, 0.81267292], [ 0.22397687, 0.53086094, 0.8137141 ], [ 0.21977058, 0.52241482, 0.81457651], [ 0.21658989, 0.51384321, 0.81528511], [ 0.21452772, 0.50514155, 0.81577278], [ 0.21372783, 0.49630865, 0.81589566], [ 0.21409503, 0.48734861, 0.81566163], [ 0.2157176 , 0.47827123, 0.81487615], [ 0.21842857, 0.46909168, 0.81351614], [ 0.22211705, 0.45983212, 0.81146983], [ 0.22665681, 0.45052233, 0.80860217], [ 0.23176013, 0.44119137, 0.80494325], [ 0.23727775, 0.43187704, 0.80038017], [ 0.24298285, 0.42261123, 0.79493267], [ 0.24865068, 0.41341842, 0.78869164], [ 0.25423116, 0.40433127, 0.78155831], [ 0.25950239, 
0.39535521, 0.77376848], [ 0.2644736 , 0.38651212, 0.76524809], [ 0.26901584, 0.37779582, 0.75621942], [ 0.27318141, 0.36922056, 0.746605 ], [ 0.27690355, 0.3607736 , 0.73659374], [ 0.28023585, 0.35244234, 0.72622103], [ 0.28306009, 0.34438449, 0.71500731], [ 0.28535896, 0.33660243, 0.70303975], [ 0.28708711, 0.32912157, 0.69034504], [ 0.28816354, 0.32200604, 0.67684067], [ 0.28862749, 0.31519824, 0.66278813], [ 0.28847904, 0.30869064, 0.6482815 ], [ 0.28770912, 0.30250126, 0.63331265], [ 0.28640325, 0.29655509, 0.61811374], [ 0.28458943, 0.29082155, 0.60280913], [ 0.28233561, 0.28527482, 0.58742866], [ 0.27967038, 0.2798938 , 0.57204225], [ 0.27665361, 0.27465357, 0.55667809], [ 0.27332564, 0.2695165 , 0.54145387], [ 0.26973851, 0.26447054, 0.52634916], [ 0.2659204 , 0.25949691, 0.511417 ], [ 0.26190145, 0.25458123, 0.49668768], [ 0.2577151 , 0.24971691, 0.48214874], [ 0.25337618, 0.24490494, 0.46778758], [ 0.24890842, 0.24013332, 0.45363816], [ 0.24433654, 0.23539226, 0.4397245 ], [ 0.23967922, 0.23067729, 0.4260591 ], [ 0.23495608, 0.22598894, 0.41262952], [ 0.23018113, 0.22132414, 0.39945577], [ 0.22534609, 0.21670847, 0.38645794], [ 0.22048761, 0.21211723, 0.37372555], [ 0.2156198 , 0.20755389, 0.36125301], [ 0.21074637, 0.20302717, 0.34903192], [ 0.20586893, 0.19855368, 0.33701661], [ 0.20101757, 0.19411573, 0.32529173], [ 0.19619947, 0.18972425, 0.31383846], [ 0.19140726, 0.18540157, 0.30260777], [ 0.1866769 , 0.1811332 , 0.29166583], [ 0.18201285, 0.17694992, 0.28088776], [ 0.17745228, 0.17282141, 0.27044211], [ 0.17300684, 0.16876921, 0.26024893], [ 0.16868273, 0.16479861, 0.25034479], [ 0.16448691, 0.16091728, 0.24075373], [ 0.16043195, 0.15714351, 0.23141745], [ 0.15652427, 0.15348248, 0.22238175], [ 0.15277065, 0.14994111, 0.21368395], [ 0.14918274, 0.14653431, 0.20529486], [ 0.14577095, 0.14327403, 0.19720829], [ 0.14254381, 0.14016944, 0.18944326], [ 0.13951035, 0.13723063, 0.18201072], [ 0.13667798, 0.13446606, 0.17493774], [ 0.13405762, 0.13188822, 
0.16820842], [ 0.13165767, 0.12950667, 0.16183275], [ 0.12948748, 0.12733187, 0.15580631], [ 0.12755435, 0.1253723 , 0.15014098], [ 0.12586516, 0.12363617, 0.1448459 ], [ 0.12442647, 0.12213143, 0.13992571], [ 0.12324241, 0.12086419, 0.13539995], [ 0.12232067, 0.11984278, 0.13124644], [ 0.12166209, 0.11907077, 0.12749671], [ 0.12126982, 0.11855309, 0.12415079], [ 0.12114244, 0.11829179, 0.1212385 ], [ 0.12127766, 0.11828837, 0.11878534], [ 0.12284806, 0.1179729 , 0.11772022], [ 0.12619498, 0.11721796, 0.11770203], [ 0.129968 , 0.11663788, 0.11792377], [ 0.13410011, 0.11625146, 0.11839138], [ 0.13855459, 0.11606618, 0.11910584], [ 0.14333775, 0.11607038, 0.1200606 ], [ 0.148417 , 0.11626929, 0.12125453], [ 0.15377389, 0.11666192, 0.12268364], [ 0.15941427, 0.11723486, 0.12433911], [ 0.16533376, 0.11797856, 0.12621303], [ 0.17152547, 0.11888403, 0.12829735], [ 0.17797765, 0.11994436, 0.13058435], [ 0.18468769, 0.12114722, 0.13306426], [ 0.19165663, 0.12247737, 0.13572616], [ 0.19884415, 0.12394381, 0.1385669 ], [ 0.20627181, 0.12551883, 0.14157124], [ 0.21394877, 0.12718055, 0.14472604], [ 0.22184572, 0.12893119, 0.14802579], [ 0.22994394, 0.13076731, 0.15146314], [ 0.23823937, 0.13267611, 0.15502793], [ 0.24676041, 0.13462172, 0.15870321], [ 0.25546457, 0.13661751, 0.16248722], [ 0.26433628, 0.13865956, 0.16637301], [ 0.27341345, 0.14070412, 0.17034221], [ 0.28264773, 0.14277192, 0.1743957 ], [ 0.29202272, 0.14486161, 0.17852793], [ 0.30159648, 0.14691224, 0.1827169 ], [ 0.31129002, 0.14897583, 0.18695213], [ 0.32111555, 0.15103351, 0.19119629], [ 0.33107961, 0.1530674 , 0.19543758], [ 0.34119892, 0.15504762, 0.1996803 ], [ 0.35142388, 0.15701131, 0.20389086], [ 0.36178937, 0.1589124 , 0.20807639], [ 0.37229381, 0.16073993, 0.21223189], [ 0.38288348, 0.16254006, 0.2163249 ], [ 0.39359592, 0.16426336, 0.22036577], [ 0.40444332, 0.16588767, 0.22434027], [ 0.41537995, 0.16745325, 0.2282297 ], [ 0.42640867, 0.16894939, 0.23202755], [ 0.43754706, 0.17034847, 0.23572899], 
[ 0.44878564, 0.1716535 , 0.23932344], [ 0.4601126 , 0.17287365, 0.24278607], [ 0.47151732, 0.17401641, 0.24610337], [ 0.48300689, 0.17506676, 0.2492737 ], [ 0.49458302, 0.17601892, 0.25227688], [ 0.50623876, 0.17687777, 0.255096 ], [ 0.5179623 , 0.17765528, 0.2577162 ], [ 0.52975234, 0.17835232, 0.2601134 ], [ 0.54159776, 0.17898292, 0.26226847], [ 0.55348804, 0.17956232, 0.26416003], [ 0.56541729, 0.18010175, 0.26575971], [ 0.57736669, 0.180631 , 0.26704888], [ 0.58932081, 0.18117827, 0.26800409], [ 0.60127582, 0.18175888, 0.26858488], [ 0.61319563, 0.1824336 , 0.2687872 ], [ 0.62506376, 0.18324015, 0.26858301], [ 0.63681202, 0.18430173, 0.26795276], [ 0.64842603, 0.18565472, 0.26689463], [ 0.65988195, 0.18734638, 0.26543435], [ 0.67111966, 0.18948885, 0.26357955], [ 0.68209194, 0.19216636, 0.26137175], [ 0.69281185, 0.19535326, 0.25887063], [ 0.70335022, 0.19891271, 0.25617971], [ 0.71375229, 0.20276438, 0.25331365], [ 0.72401436, 0.20691287, 0.25027366], [ 0.73407638, 0.21145051, 0.24710661], [ 0.74396983, 0.21631913, 0.24380715], [ 0.75361506, 0.22163653, 0.24043996], [ 0.7630579 , 0.22731637, 0.23700095], [ 0.77222228, 0.23346231, 0.23356628], [ 0.78115441, 0.23998404, 0.23013825], [ 0.78979746, 0.24694858, 0.22678822], [ 0.79819286, 0.25427223, 0.22352658], [ 0.80630444, 0.26198807, 0.22040877], [ 0.81417437, 0.27001406, 0.21744645], [ 0.82177364, 0.27837336, 0.21468316], [ 0.82915955, 0.28696963, 0.21210766], [ 0.83628628, 0.2958499 , 0.20977813], [ 0.84322168, 0.30491136, 0.20766435], [ 0.84995458, 0.31415945, 0.2057863 ], [ 0.85648867, 0.32358058, 0.20415327], [ 0.86286243, 0.33312058, 0.20274969], [ 0.86908321, 0.34276705, 0.20157271], [ 0.87512876, 0.3525416 , 0.20064949], [ 0.88100349, 0.36243385, 0.19999078], [ 0.8866469 , 0.37249496, 0.1997976 ], [ 0.89203964, 0.38273475, 0.20013431], [ 0.89713496, 0.39318156, 0.20121514], [ 0.90195099, 0.40380687, 0.20301555], [ 0.90648379, 0.41460191, 0.20558847], [ 0.9106967 , 0.42557857, 0.20918529], [ 
0.91463791, 0.43668557, 0.21367954], [ 0.91830723, 0.44790913, 0.21916352], [ 0.92171507, 0.45922856, 0.22568002], [ 0.92491786, 0.4705936 , 0.23308207], [ 0.92790792, 0.48200153, 0.24145932], [ 0.93073701, 0.49341219, 0.25065486], [ 0.93343918, 0.5048017 , 0.26056148], [ 0.93602064, 0.51616486, 0.27118485], [ 0.93850535, 0.52748892, 0.28242464], [ 0.94092933, 0.53875462, 0.29416042], [ 0.94330011, 0.5499628 , 0.30634189], [ 0.94563159, 0.56110987, 0.31891624], [ 0.94792955, 0.57219822, 0.33184256], [ 0.95020929, 0.5832232 , 0.34508419], [ 0.95247324, 0.59419035, 0.35859866], [ 0.95471709, 0.60510869, 0.37236035], [ 0.95698411, 0.61595766, 0.38629631], [ 0.95923863, 0.62676473, 0.40043317], [ 0.9615041 , 0.6375203 , 0.41474106], [ 0.96371553, 0.64826619, 0.42928335], [ 0.96591497, 0.65899621, 0.44380444], [ 0.96809871, 0.66971662, 0.45830232], [ 0.9702495 , 0.6804394 , 0.47280492], [ 0.9723881 , 0.69115622, 0.48729272], [ 0.97450723, 0.70187358, 0.50178034], [ 0.9766108 , 0.712592 , 0.51626837], [ 0.97871716, 0.72330511, 0.53074053], [ 0.98082222, 0.73401769, 0.54520694], [ 0.9829001 , 0.74474445, 0.5597019 ], [ 0.98497466, 0.75547635, 0.57420239], [ 0.98705581, 0.76621129, 0.58870185], [ 0.98913325, 0.77695637, 0.60321626], [ 0.99119918, 0.78771716, 0.61775821], [ 0.9932672 , 0.79848979, 0.63231691], [ 0.99535958, 0.80926704, 0.64687278], [ 0.99740544, 0.82008078, 0.66150571], [ 0.9992197 , 0.83100723, 0.6764127 ] ] _flare_lut = [ [0.92907237, 0.68878959, 0.50411509], [0.92891402, 0.68494686, 0.50173994], [0.92864754, 0.68116207, 0.4993754], [0.92836112, 0.67738527, 0.49701572], [0.9280599, 0.67361354, 0.49466044], [0.92775569, 0.66983999, 0.49230866], [0.9274375, 0.66607098, 0.48996097], [0.927111, 0.66230315, 0.48761688], [0.92677996, 0.6585342, 0.485276], [0.92644317, 0.65476476, 0.48293832], [0.92609759, 0.65099658, 0.48060392], [0.925747, 0.64722729, 0.47827244], [0.92539502, 0.64345456, 0.47594352], [0.92503106, 0.6396848, 0.47361782], [0.92466877, 
0.6359095, 0.47129427], [0.92429828, 0.63213463, 0.46897349], [0.92392172, 0.62835879, 0.46665526], [0.92354597, 0.62457749, 0.46433898], [0.9231622, 0.6207962, 0.46202524], [0.92277222, 0.61701365, 0.45971384], [0.92237978, 0.61322733, 0.45740444], [0.92198615, 0.60943622, 0.45509686], [0.92158735, 0.60564276, 0.45279137], [0.92118373, 0.60184659, 0.45048789], [0.92077582, 0.59804722, 0.44818634], [0.92036413, 0.59424414, 0.44588663], [0.91994924, 0.5904368, 0.44358868], [0.91952943, 0.58662619, 0.4412926], [0.91910675, 0.58281075, 0.43899817], [0.91868096, 0.57899046, 0.4367054], [0.91825103, 0.57516584, 0.43441436], [0.91781857, 0.57133556, 0.43212486], [0.9173814, 0.56750099, 0.4298371], [0.91694139, 0.56366058, 0.42755089], [0.91649756, 0.55981483, 0.42526631], [0.91604942, 0.55596387, 0.42298339], [0.9155979, 0.55210684, 0.42070204], [0.9151409, 0.54824485, 0.4184247], [0.91466138, 0.54438817, 0.41617858], [0.91416896, 0.54052962, 0.41396347], [0.91366559, 0.53666778, 0.41177769], [0.91315173, 0.53280208, 0.40962196], [0.91262605, 0.52893336, 0.40749715], [0.91208866, 0.52506133, 0.40540404], [0.91153952, 0.52118582, 0.40334346], [0.91097732, 0.51730767, 0.4013163], [0.910403, 0.51342591, 0.39932342], [0.90981494, 0.50954168, 0.39736571], [0.90921368, 0.5056543, 0.39544411], [0.90859797, 0.50176463, 0.39355952], [0.90796841, 0.49787195, 0.39171297], [0.90732341, 0.4939774, 0.38990532], [0.90666382, 0.49008006, 0.38813773], [0.90598815, 0.486181, 0.38641107], [0.90529624, 0.48228017, 0.38472641], [0.90458808, 0.47837738, 0.38308489], [0.90386248, 0.47447348, 0.38148746], [0.90311921, 0.4705685, 0.37993524], [0.90235809, 0.46666239, 0.37842943], [0.90157824, 0.46275577, 0.37697105], [0.90077904, 0.45884905, 0.37556121], [0.89995995, 0.45494253, 0.37420106], [0.89912041, 0.4510366, 0.37289175], [0.8982602, 0.44713126, 0.37163458], [0.89737819, 0.44322747, 0.37043052], [0.89647387, 0.43932557, 0.36928078], [0.89554477, 0.43542759, 0.36818855], [0.89458871, 
0.4315354, 0.36715654], [0.89360794, 0.42764714, 0.36618273], [0.89260152, 0.42376366, 0.36526813], [0.8915687, 0.41988565, 0.36441384], [0.89050882, 0.41601371, 0.36362102], [0.8894159, 0.41215334, 0.36289639], [0.888292, 0.40830288, 0.36223756], [0.88713784, 0.40446193, 0.36164328], [0.88595253, 0.40063149, 0.36111438], [0.88473115, 0.39681635, 0.3606566], [0.88347246, 0.39301805, 0.36027074], [0.88217931, 0.38923439, 0.35995244], [0.880851, 0.38546632, 0.35970244], [0.87947728, 0.38172422, 0.35953127], [0.87806542, 0.37800172, 0.35942941], [0.87661509, 0.37429964, 0.35939659], [0.87511668, 0.37062819, 0.35944178], [0.87357554, 0.36698279, 0.35955811], [0.87199254, 0.3633634, 0.35974223], [0.87035691, 0.35978174, 0.36000516], [0.86867647, 0.35623087, 0.36033559], [0.86694949, 0.35271349, 0.36073358], [0.86516775, 0.34923921, 0.36120624], [0.86333996, 0.34580008, 0.36174113], [0.86145909, 0.3424046, 0.36234402], [0.85952586, 0.33905327, 0.36301129], [0.85754536, 0.33574168, 0.36373567], [0.855514, 0.33247568, 0.36451271], [0.85344392, 0.32924217, 0.36533344], [0.8513284, 0.32604977, 0.36620106], [0.84916723, 0.32289973, 0.36711424], [0.84696243, 0.31979068, 0.36806976], [0.84470627, 0.31673295, 0.36907066], [0.84240761, 0.31371695, 0.37010969], [0.84005337, 0.31075974, 0.37119284], [0.83765537, 0.30784814, 0.3723105], [0.83520234, 0.30499724, 0.37346726], [0.83270291, 0.30219766, 0.37465552], [0.83014895, 0.29946081, 0.37587769], [0.82754694, 0.29677989, 0.37712733], [0.82489111, 0.29416352, 0.37840532], [0.82218644, 0.29160665, 0.37970606], [0.81942908, 0.28911553, 0.38102921], [0.81662276, 0.28668665, 0.38236999], [0.81376555, 0.28432371, 0.383727], [0.81085964, 0.28202508, 0.38509649], [0.8079055, 0.27979128, 0.38647583], [0.80490309, 0.27762348, 0.3878626], [0.80185613, 0.2755178, 0.38925253], [0.79876118, 0.27347974, 0.39064559], [0.79562644, 0.27149928, 0.39203532], [0.79244362, 0.2695883, 0.39342447], [0.78922456, 0.26773176, 0.3948046], [0.78596161, 
0.26594053, 0.39617873], [0.7826624, 0.26420493, 0.39754146], [0.77932717, 0.26252522, 0.39889102], [0.77595363, 0.2609049, 0.4002279], [0.77254999, 0.25933319, 0.40154704], [0.76911107, 0.25781758, 0.40284959], [0.76564158, 0.25635173, 0.40413341], [0.76214598, 0.25492998, 0.40539471], [0.75861834, 0.25356035, 0.40663694], [0.75506533, 0.25223402, 0.40785559], [0.75148963, 0.2509473, 0.40904966], [0.74788835, 0.24970413, 0.41022028], [0.74426345, 0.24850191, 0.41136599], [0.74061927, 0.24733457, 0.41248516], [0.73695678, 0.24620072, 0.41357737], [0.73327278, 0.24510469, 0.41464364], [0.72957096, 0.24404127, 0.4156828], [0.72585394, 0.24300672, 0.41669383], [0.7221226, 0.24199971, 0.41767651], [0.71837612, 0.24102046, 0.41863486], [0.71463236, 0.24004289, 0.41956983], [0.7108932, 0.23906316, 0.42048681], [0.70715842, 0.23808142, 0.42138647], [0.70342811, 0.2370976, 0.42226844], [0.69970218, 0.23611179, 0.42313282], [0.69598055, 0.2351247, 0.42397678], [0.69226314, 0.23413578, 0.42480327], [0.68854988, 0.23314511, 0.42561234], [0.68484064, 0.23215279, 0.42640419], [0.68113541, 0.23115942, 0.42717615], [0.67743412, 0.23016472, 0.42792989], [0.67373662, 0.22916861, 0.42866642], [0.67004287, 0.22817117, 0.42938576], [0.66635279, 0.22717328, 0.43008427], [0.66266621, 0.22617435, 0.43076552], [0.65898313, 0.22517434, 0.43142956], [0.65530349, 0.22417381, 0.43207427], [0.65162696, 0.22317307, 0.4327001], [0.64795375, 0.22217149, 0.43330852], [0.64428351, 0.22116972, 0.43389854], [0.64061624, 0.22016818, 0.43446845], [0.63695183, 0.21916625, 0.43502123], [0.63329016, 0.21816454, 0.43555493], [0.62963102, 0.2171635, 0.43606881], [0.62597451, 0.21616235, 0.43656529], [0.62232019, 0.21516239, 0.43704153], [0.61866821, 0.21416307, 0.43749868], [0.61501835, 0.21316435, 0.43793808], [0.61137029, 0.21216761, 0.4383556], [0.60772426, 0.2111715, 0.43875552], [0.60407977, 0.21017746, 0.43913439], [0.60043678, 0.20918503, 0.43949412], [0.59679524, 0.20819447, 0.43983393], 
[0.59315487, 0.20720639, 0.44015254], [0.58951566, 0.20622027, 0.44045213], [0.58587715, 0.20523751, 0.44072926], [0.5822395, 0.20425693, 0.44098758], [0.57860222, 0.20328034, 0.44122241], [0.57496549, 0.20230637, 0.44143805], [0.57132875, 0.20133689, 0.4416298], [0.56769215, 0.20037071, 0.44180142], [0.5640552, 0.19940936, 0.44194923], [0.56041794, 0.19845221, 0.44207535], [0.55678004, 0.1975, 0.44217824], [0.55314129, 0.19655316, 0.44225723], [0.54950166, 0.19561118, 0.44231412], [0.54585987, 0.19467771, 0.44234111], [0.54221157, 0.19375869, 0.44233698], [0.5385549, 0.19285696, 0.44229959], [0.5348913, 0.19197036, 0.44222958], [0.53122177, 0.1910974, 0.44212735], [0.52754464, 0.19024042, 0.44199159], [0.52386353, 0.18939409, 0.44182449], [0.52017476, 0.18856368, 0.44162345], [0.51648277, 0.18774266, 0.44139128], [0.51278481, 0.18693492, 0.44112605], [0.50908361, 0.18613639, 0.4408295], [0.50537784, 0.18534893, 0.44050064], [0.50166912, 0.18457008, 0.44014054], [0.49795686, 0.18380056, 0.43974881], [0.49424218, 0.18303865, 0.43932623], [0.49052472, 0.18228477, 0.43887255], [0.48680565, 0.1815371, 0.43838867], [0.48308419, 0.18079663, 0.43787408], [0.47936222, 0.18006056, 0.43733022], [0.47563799, 0.17933127, 0.43675585], [0.47191466, 0.17860416, 0.43615337], [0.46818879, 0.17788392, 0.43552047], [0.46446454, 0.17716458, 0.43486036], [0.46073893, 0.17645017, 0.43417097], [0.45701462, 0.17573691, 0.43345429], [0.45329097, 0.17502549, 0.43271025], [0.44956744, 0.17431649, 0.4319386], [0.44584668, 0.17360625, 0.43114133], [0.44212538, 0.17289906, 0.43031642], [0.43840678, 0.17219041, 0.42946642], [0.43469046, 0.17148074, 0.42859124], [0.4309749, 0.17077192, 0.42769008], [0.42726297, 0.17006003, 0.42676519], [0.42355299, 0.16934709, 0.42581586], [0.41984535, 0.16863258, 0.42484219], [0.41614149, 0.16791429, 0.42384614], [0.41244029, 0.16719372, 0.42282661], [0.40874177, 0.16647061, 0.42178429], [0.40504765, 0.16574261, 0.42072062], [0.401357, 0.16501079, 0.41963528], 
[0.397669, 0.16427607, 0.418528], [0.39398585, 0.16353554, 0.41740053], [0.39030735, 0.16278924, 0.41625344], [0.3866314, 0.16203977, 0.41508517], [0.38295904, 0.16128519, 0.41389849], [0.37928736, 0.16052483, 0.41270599], [0.37562649, 0.15974704, 0.41151182], [0.37197803, 0.15895049, 0.41031532], [0.36833779, 0.15813871, 0.40911916], [0.36470944, 0.15730861, 0.40792149], [0.36109117, 0.15646169, 0.40672362], [0.35748213, 0.15559861, 0.40552633], [0.353885, 0.15471714, 0.40432831], [0.35029682, 0.15381967, 0.4031316], [0.34671861, 0.1529053, 0.40193587], [0.34315191, 0.15197275, 0.40074049], [0.33959331, 0.15102466, 0.3995478], [0.33604378, 0.15006017, 0.39835754], [0.33250529, 0.14907766, 0.39716879], [0.32897621, 0.14807831, 0.39598285], [0.3254559, 0.14706248, 0.39480044], [0.32194567, 0.14602909, 0.39362106], [0.31844477, 0.14497857, 0.39244549], [0.31494974, 0.14391333, 0.39127626], [0.31146605, 0.14282918, 0.39011024], [0.30798857, 0.1417297, 0.38895105], [0.30451661, 0.14061515, 0.38779953], [0.30105136, 0.13948445, 0.38665531], [0.2975886, 0.1383403, 0.38552159], [0.29408557, 0.13721193, 0.38442775] ] _crest_lut = [ [0.6468274, 0.80289262, 0.56592265], [0.64233318, 0.80081141, 0.56639461], [0.63791969, 0.7987162, 0.56674976], [0.6335316, 0.79661833, 0.56706128], [0.62915226, 0.7945212, 0.56735066], [0.62477862, 0.79242543, 0.56762143], [0.62042003, 0.79032918, 0.56786129], [0.61606327, 0.78823508, 0.56808666], [0.61171322, 0.78614216, 0.56829092], [0.60736933, 0.78405055, 0.56847436], [0.60302658, 0.78196121, 0.56864272], [0.59868708, 0.77987374, 0.56879289], [0.59435366, 0.77778758, 0.56892099], [0.59001953, 0.77570403, 0.56903477], [0.58568753, 0.77362254, 0.56913028], [0.58135593, 0.77154342, 0.56920908], [0.57702623, 0.76946638, 0.56926895], [0.57269165, 0.76739266, 0.5693172], [0.56835934, 0.76532092, 0.56934507], [0.56402533, 0.76325185, 0.56935664], [0.55968429, 0.76118643, 0.56935732], [0.55534159, 0.75912361, 0.56934052], [0.55099572, 0.75706366, 
0.56930743], [0.54664626, 0.75500662, 0.56925799], [0.54228969, 0.75295306, 0.56919546], [0.53792417, 0.75090328, 0.56912118], [0.53355172, 0.74885687, 0.5690324], [0.52917169, 0.74681387, 0.56892926], [0.52478243, 0.74477453, 0.56881287], [0.52038338, 0.74273888, 0.56868323], [0.5159739, 0.74070697, 0.56854039], [0.51155269, 0.73867895, 0.56838507], [0.50711872, 0.73665492, 0.56821764], [0.50267118, 0.73463494, 0.56803826], [0.49822926, 0.73261388, 0.56785146], [0.49381422, 0.73058524, 0.56767484], [0.48942421, 0.72854938, 0.56751036], [0.48505993, 0.72650623, 0.56735752], [0.48072207, 0.72445575, 0.56721583], [0.4764113, 0.72239788, 0.56708475], [0.47212827, 0.72033258, 0.56696376], [0.46787361, 0.71825983, 0.56685231], [0.46364792, 0.71617961, 0.56674986], [0.45945271, 0.71409167, 0.56665625], [0.45528878, 0.71199595, 0.56657103], [0.45115557, 0.70989276, 0.5664931], [0.44705356, 0.70778212, 0.56642189], [0.44298321, 0.70566406, 0.56635683], [0.43894492, 0.70353863, 0.56629734], [0.43493911, 0.70140588, 0.56624286], [0.43096612, 0.69926587, 0.5661928], [0.42702625, 0.69711868, 0.56614659], [0.42311977, 0.69496438, 0.56610368], [0.41924689, 0.69280308, 0.56606355], [0.41540778, 0.69063486, 0.56602564], [0.41160259, 0.68845984, 0.56598944], [0.40783143, 0.68627814, 0.56595436], [0.40409434, 0.68408988, 0.56591994], [0.40039134, 0.68189518, 0.56588564], [0.39672238, 0.6796942, 0.56585103], [0.39308781, 0.67748696, 0.56581581], [0.38949137, 0.67527276, 0.56578084], [0.38592889, 0.67305266, 0.56574422], [0.38240013, 0.67082685, 0.56570561], [0.37890483, 0.66859548, 0.56566462], [0.37544276, 0.66635871, 0.56562081], [0.37201365, 0.66411673, 0.56557372], [0.36861709, 0.6618697, 0.5655231], [0.36525264, 0.65961782, 0.56546873], [0.36191986, 0.65736125, 0.56541032], [0.35861935, 0.65509998, 0.56534768], [0.35535621, 0.65283302, 0.56528211], [0.35212361, 0.65056188, 0.56521171], [0.34892097, 0.64828676, 0.56513633], [0.34574785, 0.64600783, 0.56505539], [0.34260357, 
0.64372528, 0.5649689], [0.33948744, 0.64143931, 0.56487679], [0.33639887, 0.6391501, 0.56477869], [0.33334501, 0.63685626, 0.56467661], [0.33031952, 0.63455911, 0.564569], [0.3273199, 0.63225924, 0.56445488], [0.32434526, 0.62995682, 0.56433457], [0.32139487, 0.62765201, 0.56420795], [0.31846807, 0.62534504, 0.56407446], [0.3155731, 0.62303426, 0.56393695], [0.31270304, 0.62072111, 0.56379321], [0.30985436, 0.61840624, 0.56364307], [0.30702635, 0.61608984, 0.56348606], [0.30421803, 0.61377205, 0.56332267], [0.30143611, 0.61145167, 0.56315419], [0.29867863, 0.60912907, 0.56298054], [0.29593872, 0.60680554, 0.56280022], [0.29321538, 0.60448121, 0.56261376], [0.2905079, 0.60215628, 0.56242036], [0.28782827, 0.5998285, 0.56222366], [0.28516521, 0.59749996, 0.56202093], [0.28251558, 0.59517119, 0.56181204], [0.27987847, 0.59284232, 0.56159709], [0.27726216, 0.59051189, 0.56137785], [0.27466434, 0.58818027, 0.56115433], [0.2720767, 0.58584893, 0.56092486], [0.26949829, 0.58351797, 0.56068983], [0.26693801, 0.58118582, 0.56045121], [0.26439366, 0.57885288, 0.56020858], [0.26185616, 0.57652063, 0.55996077], [0.25932459, 0.57418919, 0.55970795], [0.25681303, 0.57185614, 0.55945297], [0.25431024, 0.56952337, 0.55919385], [0.25180492, 0.56719255, 0.5589305], [0.24929311, 0.56486397, 0.5586654], [0.24678356, 0.56253666, 0.55839491], [0.24426587, 0.56021153, 0.55812473], [0.24174022, 0.55788852, 0.55785448], [0.23921167, 0.55556705, 0.55758211], [0.23668315, 0.55324675, 0.55730676], [0.23414742, 0.55092825, 0.55703167], [0.23160473, 0.54861143, 0.5567573], [0.22905996, 0.54629572, 0.55648168], [0.22651648, 0.54398082, 0.5562029], [0.22396709, 0.54166721, 0.55592542], [0.22141221, 0.53935481, 0.55564885], [0.21885269, 0.53704347, 0.55537294], [0.21629986, 0.53473208, 0.55509319], [0.21374297, 0.53242154, 0.5548144], [0.21118255, 0.53011166, 0.55453708], [0.2086192, 0.52780237, 0.55426067], [0.20605624, 0.52549322, 0.55398479], [0.20350004, 0.5231837, 0.55370601], [0.20094292, 
0.52087429, 0.55342884], [0.19838567, 0.51856489, 0.55315283], [0.19582911, 0.51625531, 0.55287818], [0.19327413, 0.51394542, 0.55260469], [0.19072933, 0.51163448, 0.5523289], [0.18819045, 0.50932268, 0.55205372], [0.18565609, 0.50701014, 0.55177937], [0.18312739, 0.50469666, 0.55150597], [0.18060561, 0.50238204, 0.55123374], [0.178092, 0.50006616, 0.55096224], [0.17558808, 0.49774882, 0.55069118], [0.17310341, 0.49542924, 0.5504176], [0.17063111, 0.49310789, 0.55014445], [0.1681728, 0.49078458, 0.54987159], [0.1657302, 0.48845913, 0.54959882], [0.16330517, 0.48613135, 0.54932605], [0.16089963, 0.48380104, 0.54905306], [0.15851561, 0.48146803, 0.54877953], [0.15615526, 0.47913212, 0.54850526], [0.15382083, 0.47679313, 0.54822991], [0.15151471, 0.47445087, 0.54795318], [0.14924112, 0.47210502, 0.54767411], [0.1470032, 0.46975537, 0.54739226], [0.14480101, 0.46740187, 0.54710832], [0.14263736, 0.46504434, 0.54682188], [0.14051521, 0.46268258, 0.54653253], [0.13843761, 0.46031639, 0.54623985], [0.13640774, 0.45794558, 0.5459434], [0.13442887, 0.45556994, 0.54564272], [0.1325044, 0.45318928, 0.54533736], [0.13063777, 0.4508034, 0.54502674], [0.12883252, 0.44841211, 0.5447104], [0.12709242, 0.44601517, 0.54438795], [0.1254209, 0.44361244, 0.54405855], [0.12382162, 0.44120373, 0.54372156], [0.12229818, 0.43878887, 0.54337634], [0.12085453, 0.4363676, 0.54302253], [0.11949938, 0.43393955, 0.54265715], [0.11823166, 0.43150478, 0.54228104], [0.11705496, 0.42906306, 0.54189388], [0.115972, 0.42661431, 0.54149449], [0.11498598, 0.42415835, 0.54108222], [0.11409965, 0.42169502, 0.54065622], [0.11331533, 0.41922424, 0.5402155], [0.11263542, 0.41674582, 0.53975931], [0.1120615, 0.4142597, 0.53928656], [0.11159738, 0.41176567, 0.53879549], [0.11125248, 0.40926325, 0.53828203], [0.11101698, 0.40675289, 0.53774864], [0.11089152, 0.40423445, 0.53719455], [0.11085121, 0.4017095, 0.53662425], [0.11087217, 0.39917938, 0.53604354], [0.11095515, 0.39664394, 0.53545166], [0.11110676, 
0.39410282, 0.53484509], [0.11131735, 0.39155635, 0.53422678], [0.11158595, 0.38900446, 0.53359634], [0.11191139, 0.38644711, 0.5329534], [0.11229224, 0.38388426, 0.53229748], [0.11273683, 0.38131546, 0.53162393], [0.11323438, 0.37874109, 0.53093619], [0.11378271, 0.37616112, 0.53023413], [0.11437992, 0.37357557, 0.52951727], [0.11502681, 0.37098429, 0.52878396], [0.11572661, 0.36838709, 0.52803124], [0.11646936, 0.36578429, 0.52726234], [0.11725299, 0.3631759, 0.52647685], [0.1180755, 0.36056193, 0.52567436], [0.1189438, 0.35794203, 0.5248497], [0.11984752, 0.35531657, 0.52400649], [0.1207833, 0.35268564, 0.52314492], [0.12174895, 0.35004927, 0.52226461], [0.12274959, 0.34740723, 0.52136104], [0.12377809, 0.34475975, 0.52043639], [0.12482961, 0.34210702, 0.51949179], [0.125902, 0.33944908, 0.51852688], [0.12699998, 0.33678574, 0.51753708], [0.12811691, 0.33411727, 0.51652464], [0.12924811, 0.33144384, 0.51549084], [0.13039157, 0.32876552, 0.51443538], [0.13155228, 0.32608217, 0.51335321], [0.13272282, 0.32339407, 0.51224759], [0.13389954, 0.32070138, 0.51111946], [0.13508064, 0.31800419, 0.50996862], [0.13627149, 0.31530238, 0.50878942], [0.13746376, 0.31259627, 0.50758645], [0.13865499, 0.30988598, 0.50636017], [0.13984364, 0.30717161, 0.50511042], [0.14103515, 0.30445309, 0.50383119], [0.14222093, 0.30173071, 0.50252813], [0.14339946, 0.2990046, 0.50120127], [0.14456941, 0.29627483, 0.49985054], [0.14573579, 0.29354139, 0.49847009], [0.14689091, 0.29080452, 0.49706566], [0.1480336, 0.28806432, 0.49563732], [0.1491628, 0.28532086, 0.49418508], [0.15028228, 0.28257418, 0.49270402], [0.15138673, 0.27982444, 0.49119848], [0.15247457, 0.27707172, 0.48966925], [0.15354487, 0.2743161, 0.48811641], [0.15459955, 0.27155765, 0.4865371], [0.15563716, 0.26879642, 0.4849321], [0.1566572, 0.26603191, 0.48330429], [0.15765823, 0.26326032, 0.48167456], [0.15862147, 0.26048295, 0.48005785], [0.15954301, 0.25770084, 0.47845341], [0.16043267, 0.25491144, 0.4768626], [0.16129262, 
0.25211406, 0.4752857], [0.1621119, 0.24931169, 0.47372076], [0.16290577, 0.24649998, 0.47217025], [0.16366819, 0.24368054, 0.47063302], [0.1644021, 0.24085237, 0.46910949], [0.16510882, 0.2380149, 0.46759982], [0.16579015, 0.23516739, 0.46610429], [0.1664433, 0.2323105, 0.46462219], [0.16707586, 0.22944155, 0.46315508], [0.16768475, 0.22656122, 0.46170223], [0.16826815, 0.22366984, 0.46026308], [0.16883174, 0.22076514, 0.45883891], [0.16937589, 0.21784655, 0.45742976], [0.16990129, 0.21491339, 0.45603578], [0.1704074, 0.21196535, 0.45465677], [0.17089473, 0.20900176, 0.4532928], [0.17136819, 0.20602012, 0.45194524], [0.17182683, 0.20302012, 0.45061386], [0.17227059, 0.20000106, 0.44929865], [0.17270583, 0.19695949, 0.44800165], [0.17313804, 0.19389201, 0.44672488], [0.17363177, 0.19076859, 0.44549087] ] _lut_dict = dict( rocket=_rocket_lut, mako=_mako_lut, icefire=_icefire_lut, vlag=_vlag_lut, flare=_flare_lut, crest=_crest_lut, ) for _name, _lut in _lut_dict.items(): _cmap = colors.ListedColormap(_lut, _name) locals()[_name] = _cmap _cmap_r = colors.ListedColormap(_lut[::-1], _name + "_r") locals()[_name + "_r"] = _cmap_r register_colormap(_name, _cmap) register_colormap(_name + "_r", _cmap_r) del colors, register_colormap ================================================ FILE: seaborn/colors/__init__.py ================================================ from .xkcd_rgb import xkcd_rgb # noqa: F401 from .crayons import crayons # noqa: F401 ================================================ FILE: seaborn/colors/crayons.py ================================================ crayons = {'Almond': '#EFDECD', 'Antique Brass': '#CD9575', 'Apricot': '#FDD9B5', 'Aquamarine': '#78DBE2', 'Asparagus': '#87A96B', 'Atomic Tangerine': '#FFA474', 'Banana Mania': '#FAE7B5', 'Beaver': '#9F8170', 'Bittersweet': '#FD7C6E', 'Black': '#000000', 'Blue': '#1F75FE', 'Blue Bell': '#A2A2D0', 'Blue Green': '#0D98BA', 'Blue Violet': '#7366BD', 'Blush': '#DE5D83', 'Brick Red': '#CB4154', 'Brown': 
'#B4674D', 'Burnt Orange': '#FF7F49', 'Burnt Sienna': '#EA7E5D', 'Cadet Blue': '#B0B7C6', 'Canary': '#FFFF99', 'Caribbean Green': '#00CC99', 'Carnation Pink': '#FFAACC', 'Cerise': '#DD4492', 'Cerulean': '#1DACD6', 'Chestnut': '#BC5D58', 'Copper': '#DD9475', 'Cornflower': '#9ACEEB', 'Cotton Candy': '#FFBCD9', 'Dandelion': '#FDDB6D', 'Denim': '#2B6CC4', 'Desert Sand': '#EFCDB8', 'Eggplant': '#6E5160', 'Electric Lime': '#CEFF1D', 'Fern': '#71BC78', 'Forest Green': '#6DAE81', 'Fuchsia': '#C364C5', 'Fuzzy Wuzzy': '#CC6666', 'Gold': '#E7C697', 'Goldenrod': '#FCD975', 'Granny Smith Apple': '#A8E4A0', 'Gray': '#95918C', 'Green': '#1CAC78', 'Green Yellow': '#F0E891', 'Hot Magenta': '#FF1DCE', 'Inchworm': '#B2EC5D', 'Indigo': '#5D76CB', 'Jazzberry Jam': '#CA3767', 'Jungle Green': '#3BB08F', 'Laser Lemon': '#FEFE22', 'Lavender': '#FCB4D5', 'Macaroni and Cheese': '#FFBD88', 'Magenta': '#F664AF', 'Mahogany': '#CD4A4C', 'Manatee': '#979AAA', 'Mango Tango': '#FF8243', 'Maroon': '#C8385A', 'Mauvelous': '#EF98AA', 'Melon': '#FDBCB4', 'Midnight Blue': '#1A4876', 'Mountain Meadow': '#30BA8F', 'Navy Blue': '#1974D2', 'Neon Carrot': '#FFA343', 'Olive Green': '#BAB86C', 'Orange': '#FF7538', 'Orchid': '#E6A8D7', 'Outer Space': '#414A4C', 'Outrageous Orange': '#FF6E4A', 'Pacific Blue': '#1CA9C9', 'Peach': '#FFCFAB', 'Periwinkle': '#C5D0E6', 'Piggy Pink': '#FDDDE6', 'Pine Green': '#158078', 'Pink Flamingo': '#FC74FD', 'Pink Sherbert': '#F78FA7', 'Plum': '#8E4585', 'Purple Heart': '#7442C8', "Purple Mountains' Majesty": '#9D81BA', 'Purple Pizzazz': '#FE4EDA', 'Radical Red': '#FF496C', 'Raw Sienna': '#D68A59', 'Razzle Dazzle Rose': '#FF48D0', 'Razzmatazz': '#E3256B', 'Red': '#EE204D', 'Red Orange': '#FF5349', 'Red Violet': '#C0448F', "Robin's Egg Blue": '#1FCECB', 'Royal Purple': '#7851A9', 'Salmon': '#FF9BAA', 'Scarlet': '#FC2847', "Screamin' Green": '#76FF7A', 'Sea Green': '#93DFB8', 'Sepia': '#A5694F', 'Shadow': '#8A795D', 'Shamrock': '#45CEA2', 'Shocking Pink': '#FB7EFD', 'Silver': 
'#CDC5C2', 'Sky Blue': '#80DAEB', 'Spring Green': '#ECEABE', 'Sunglow': '#FFCF48', 'Sunset Orange': '#FD5E53', 'Tan': '#FAA76C', 'Tickle Me Pink': '#FC89AC', 'Timberwolf': '#DBD7D2', 'Tropical Rain Forest': '#17806D', 'Tumbleweed': '#DEAA88', 'Turquoise Blue': '#77DDE7', 'Unmellow Yellow': '#FFFF66', 'Violet (Purple)': '#926EAE', 'Violet Red': '#F75394', 'Vivid Tangerine': '#FFA089', 'Vivid Violet': '#8F509D', 'White': '#FFFFFF', 'Wild Blue Yonder': '#A2ADD0', 'Wild Strawberry': '#FF43A4', 'Wild Watermelon': '#FC6C85', 'Wisteria': '#CDA4DE', 'Yellow': '#FCE883', 'Yellow Green': '#C5E384', 'Yellow Orange': '#FFAE42'} ================================================ FILE: seaborn/colors/xkcd_rgb.py ================================================ xkcd_rgb = {'acid green': '#8ffe09', 'adobe': '#bd6c48', 'algae': '#54ac68', 'algae green': '#21c36f', 'almost black': '#070d0d', 'amber': '#feb308', 'amethyst': '#9b5fc0', 'apple': '#6ecb3c', 'apple green': '#76cd26', 'apricot': '#ffb16d', 'aqua': '#13eac9', 'aqua blue': '#02d8e9', 'aqua green': '#12e193', 'aqua marine': '#2ee8bb', 'aquamarine': '#04d8b2', 'army green': '#4b5d16', 'asparagus': '#77ab56', 'aubergine': '#3d0734', 'auburn': '#9a3001', 'avocado': '#90b134', 'avocado green': '#87a922', 'azul': '#1d5dec', 'azure': '#069af3', 'baby blue': '#a2cffe', 'baby green': '#8cff9e', 'baby pink': '#ffb7ce', 'baby poo': '#ab9004', 'baby poop': '#937c00', 'baby poop green': '#8f9805', 'baby puke green': '#b6c406', 'baby purple': '#ca9bf7', 'baby shit brown': '#ad900d', 'baby shit green': '#889717', 'banana': '#ffff7e', 'banana yellow': '#fafe4b', 'barbie pink': '#fe46a5', 'barf green': '#94ac02', 'barney': '#ac1db8', 'barney purple': '#a00498', 'battleship grey': '#6b7c85', 'beige': '#e6daa6', 'berry': '#990f4b', 'bile': '#b5c306', 'black': '#000000', 'bland': '#afa88b', 'blood': '#770001', 'blood orange': '#fe4b03', 'blood red': '#980002', 'blue': '#0343df', 'blue blue': '#2242c7', 'blue green': '#137e6d', 'blue grey': 
'#607c8e', 'blue purple': '#5729ce', 'blue violet': '#5d06e9', 'blue with a hint of purple': '#533cc6', 'blue/green': '#0f9b8e', 'blue/grey': '#758da3', 'blue/purple': '#5a06ef', 'blueberry': '#464196', 'bluegreen': '#017a79', 'bluegrey': '#85a3b2', 'bluey green': '#2bb179', 'bluey grey': '#89a0b0', 'bluey purple': '#6241c7', 'bluish': '#2976bb', 'bluish green': '#10a674', 'bluish grey': '#748b97', 'bluish purple': '#703be7', 'blurple': '#5539cc', 'blush': '#f29e8e', 'blush pink': '#fe828c', 'booger': '#9bb53c', 'booger green': '#96b403', 'bordeaux': '#7b002c', 'boring green': '#63b365', 'bottle green': '#044a05', 'brick': '#a03623', 'brick orange': '#c14a09', 'brick red': '#8f1402', 'bright aqua': '#0bf9ea', 'bright blue': '#0165fc', 'bright cyan': '#41fdfe', 'bright green': '#01ff07', 'bright lavender': '#c760ff', 'bright light blue': '#26f7fd', 'bright light green': '#2dfe54', 'bright lilac': '#c95efb', 'bright lime': '#87fd05', 'bright lime green': '#65fe08', 'bright magenta': '#ff08e8', 'bright olive': '#9cbb04', 'bright orange': '#ff5b00', 'bright pink': '#fe01b1', 'bright purple': '#be03fd', 'bright red': '#ff000d', 'bright sea green': '#05ffa6', 'bright sky blue': '#02ccfe', 'bright teal': '#01f9c6', 'bright turquoise': '#0ffef9', 'bright violet': '#ad0afd', 'bright yellow': '#fffd01', 'bright yellow green': '#9dff00', 'british racing green': '#05480d', 'bronze': '#a87900', 'brown': '#653700', 'brown green': '#706c11', 'brown grey': '#8d8468', 'brown orange': '#b96902', 'brown red': '#922b05', 'brown yellow': '#b29705', 'brownish': '#9c6d57', 'brownish green': '#6a6e09', 'brownish grey': '#86775f', 'brownish orange': '#cb7723', 'brownish pink': '#c27e79', 'brownish purple': '#76424e', 'brownish red': '#9e3623', 'brownish yellow': '#c9b003', 'browny green': '#6f6c0a', 'browny orange': '#ca6b02', 'bruise': '#7e4071', 'bubble gum pink': '#ff69af', 'bubblegum': '#ff6cb5', 'bubblegum pink': '#fe83cc', 'buff': '#fef69e', 'burgundy': '#610023', 'burnt orange': 
'#c04e01', 'burnt red': '#9f2305', 'burnt siena': '#b75203', 'burnt sienna': '#b04e0f', 'burnt umber': '#a0450e', 'burnt yellow': '#d5ab09', 'burple': '#6832e3', 'butter': '#ffff81', 'butter yellow': '#fffd74', 'butterscotch': '#fdb147', 'cadet blue': '#4e7496', 'camel': '#c69f59', 'camo': '#7f8f4e', 'camo green': '#526525', 'camouflage green': '#4b6113', 'canary': '#fdff63', 'canary yellow': '#fffe40', 'candy pink': '#ff63e9', 'caramel': '#af6f09', 'carmine': '#9d0216', 'carnation': '#fd798f', 'carnation pink': '#ff7fa7', 'carolina blue': '#8ab8fe', 'celadon': '#befdb7', 'celery': '#c1fd95', 'cement': '#a5a391', 'cerise': '#de0c62', 'cerulean': '#0485d1', 'cerulean blue': '#056eee', 'charcoal': '#343837', 'charcoal grey': '#3c4142', 'chartreuse': '#c1f80a', 'cherry': '#cf0234', 'cherry red': '#f7022a', 'chestnut': '#742802', 'chocolate': '#3d1c02', 'chocolate brown': '#411900', 'cinnamon': '#ac4f06', 'claret': '#680018', 'clay': '#b66a50', 'clay brown': '#b2713d', 'clear blue': '#247afd', 'cloudy blue': '#acc2d9', 'cobalt': '#1e488f', 'cobalt blue': '#030aa7', 'cocoa': '#875f42', 'coffee': '#a6814c', 'cool blue': '#4984b8', 'cool green': '#33b864', 'cool grey': '#95a3a6', 'copper': '#b66325', 'coral': '#fc5a50', 'coral pink': '#ff6163', 'cornflower': '#6a79f7', 'cornflower blue': '#5170d7', 'cranberry': '#9e003a', 'cream': '#ffffc2', 'creme': '#ffffb6', 'crimson': '#8c000f', 'custard': '#fffd78', 'cyan': '#00ffff', 'dandelion': '#fedf08', 'dark': '#1b2431', 'dark aqua': '#05696b', 'dark aquamarine': '#017371', 'dark beige': '#ac9362', 'dark blue': '#00035b', 'dark blue green': '#005249', 'dark blue grey': '#1f3b4d', 'dark brown': '#341c02', 'dark coral': '#cf524e', 'dark cream': '#fff39a', 'dark cyan': '#0a888a', 'dark forest green': '#002d04', 'dark fuchsia': '#9d0759', 'dark gold': '#b59410', 'dark grass green': '#388004', 'dark green': '#033500', 'dark green blue': '#1f6357', 'dark grey': '#363737', 'dark grey blue': '#29465b', 'dark hot pink': '#d90166', 'dark 
indigo': '#1f0954', 'dark khaki': '#9b8f55', 'dark lavender': '#856798', 'dark lilac': '#9c6da5', 'dark lime': '#84b701', 'dark lime green': '#7ebd01', 'dark magenta': '#960056', 'dark maroon': '#3c0008', 'dark mauve': '#874c62', 'dark mint': '#48c072', 'dark mint green': '#20c073', 'dark mustard': '#a88905', 'dark navy': '#000435', 'dark navy blue': '#00022e', 'dark olive': '#373e02', 'dark olive green': '#3c4d03', 'dark orange': '#c65102', 'dark pastel green': '#56ae57', 'dark peach': '#de7e5d', 'dark periwinkle': '#665fd1', 'dark pink': '#cb416b', 'dark plum': '#3f012c', 'dark purple': '#35063e', 'dark red': '#840000', 'dark rose': '#b5485d', 'dark royal blue': '#02066f', 'dark sage': '#598556', 'dark salmon': '#c85a53', 'dark sand': '#a88f59', 'dark sea green': '#11875d', 'dark seafoam': '#1fb57a', 'dark seafoam green': '#3eaf76', 'dark sky blue': '#448ee4', 'dark slate blue': '#214761', 'dark tan': '#af884a', 'dark taupe': '#7f684e', 'dark teal': '#014d4e', 'dark turquoise': '#045c5a', 'dark violet': '#34013f', 'dark yellow': '#d5b60a', 'dark yellow green': '#728f02', 'darkblue': '#030764', 'darkgreen': '#054907', 'darkish blue': '#014182', 'darkish green': '#287c37', 'darkish pink': '#da467d', 'darkish purple': '#751973', 'darkish red': '#a90308', 'deep aqua': '#08787f', 'deep blue': '#040273', 'deep brown': '#410200', 'deep green': '#02590f', 'deep lavender': '#8d5eb7', 'deep lilac': '#966ebd', 'deep magenta': '#a0025c', 'deep orange': '#dc4d01', 'deep pink': '#cb0162', 'deep purple': '#36013f', 'deep red': '#9a0200', 'deep rose': '#c74767', 'deep sea blue': '#015482', 'deep sky blue': '#0d75f8', 'deep teal': '#00555a', 'deep turquoise': '#017374', 'deep violet': '#490648', 'denim': '#3b638c', 'denim blue': '#3b5b92', 'desert': '#ccad60', 'diarrhea': '#9f8303', 'dirt': '#8a6e45', 'dirt brown': '#836539', 'dirty blue': '#3f829d', 'dirty green': '#667e2c', 'dirty orange': '#c87606', 'dirty pink': '#ca7b80', 'dirty purple': '#734a65', 'dirty yellow': '#cdc50a', 
'dodger blue': '#3e82fc', 'drab': '#828344', 'drab green': '#749551', 'dried blood': '#4b0101', 'duck egg blue': '#c3fbf4', 'dull blue': '#49759c', 'dull brown': '#876e4b', 'dull green': '#74a662', 'dull orange': '#d8863b', 'dull pink': '#d5869d', 'dull purple': '#84597e', 'dull red': '#bb3f3f', 'dull teal': '#5f9e8f', 'dull yellow': '#eedc5b', 'dusk': '#4e5481', 'dusk blue': '#26538d', 'dusky blue': '#475f94', 'dusky pink': '#cc7a8b', 'dusky purple': '#895b7b', 'dusky rose': '#ba6873', 'dust': '#b2996e', 'dusty blue': '#5a86ad', 'dusty green': '#76a973', 'dusty lavender': '#ac86a8', 'dusty orange': '#f0833a', 'dusty pink': '#d58a94', 'dusty purple': '#825f87', 'dusty red': '#b9484e', 'dusty rose': '#c0737a', 'dusty teal': '#4c9085', 'earth': '#a2653e', 'easter green': '#8cfd7e', 'easter purple': '#c071fe', 'ecru': '#feffca', 'egg shell': '#fffcc4', 'eggplant': '#380835', 'eggplant purple': '#430541', 'eggshell': '#ffffd4', 'eggshell blue': '#c4fff7', 'electric blue': '#0652ff', 'electric green': '#21fc0d', 'electric lime': '#a8ff04', 'electric pink': '#ff0490', 'electric purple': '#aa23ff', 'emerald': '#01a049', 'emerald green': '#028f1e', 'evergreen': '#05472a', 'faded blue': '#658cbb', 'faded green': '#7bb274', 'faded orange': '#f0944d', 'faded pink': '#de9dac', 'faded purple': '#916e99', 'faded red': '#d3494e', 'faded yellow': '#feff7f', 'fawn': '#cfaf7b', 'fern': '#63a950', 'fern green': '#548d44', 'fire engine red': '#fe0002', 'flat blue': '#3c73a8', 'flat green': '#699d4c', 'fluorescent green': '#08ff08', 'fluro green': '#0aff02', 'foam green': '#90fda9', 'forest': '#0b5509', 'forest green': '#06470c', 'forrest green': '#154406', 'french blue': '#436bad', 'fresh green': '#69d84f', 'frog green': '#58bc08', 'fuchsia': '#ed0dd9', 'gold': '#dbb40c', 'golden': '#f5bf03', 'golden brown': '#b27a01', 'golden rod': '#f9bc08', 'golden yellow': '#fec615', 'goldenrod': '#fac205', 'grape': '#6c3461', 'grape purple': '#5d1451', 'grapefruit': '#fd5956', 'grass': '#5cac2d', 
'grass green': '#3f9b0b', 'grassy green': '#419c03', 'green': '#15b01a', 'green apple': '#5edc1f', 'green blue': '#06b48b', 'green brown': '#544e03', 'green grey': '#77926f', 'green teal': '#0cb577', 'green yellow': '#c9ff27', 'green/blue': '#01c08d', 'green/yellow': '#b5ce08', 'greenblue': '#23c48b', 'greenish': '#40a368', 'greenish beige': '#c9d179', 'greenish blue': '#0b8b87', 'greenish brown': '#696112', 'greenish cyan': '#2afeb7', 'greenish grey': '#96ae8d', 'greenish tan': '#bccb7a', 'greenish teal': '#32bf84', 'greenish turquoise': '#00fbb0', 'greenish yellow': '#cdfd02', 'greeny blue': '#42b395', 'greeny brown': '#696006', 'greeny grey': '#7ea07a', 'greeny yellow': '#c6f808', 'grey': '#929591', 'grey blue': '#6b8ba4', 'grey brown': '#7f7053', 'grey green': '#789b73', 'grey pink': '#c3909b', 'grey purple': '#826d8c', 'grey teal': '#5e9b8a', 'grey/blue': '#647d8e', 'grey/green': '#86a17d', 'greyblue': '#77a1b5', 'greyish': '#a8a495', 'greyish blue': '#5e819d', 'greyish brown': '#7a6a4f', 'greyish green': '#82a67d', 'greyish pink': '#c88d94', 'greyish purple': '#887191', 'greyish teal': '#719f91', 'gross green': '#a0bf16', 'gunmetal': '#536267', 'hazel': '#8e7618', 'heather': '#a484ac', 'heliotrope': '#d94ff5', 'highlighter green': '#1bfc06', 'hospital green': '#9be5aa', 'hot green': '#25ff29', 'hot magenta': '#f504c9', 'hot pink': '#ff028d', 'hot purple': '#cb00f5', 'hunter green': '#0b4008', 'ice': '#d6fffa', 'ice blue': '#d7fffe', 'icky green': '#8fae22', 'indian red': '#850e04', 'indigo': '#380282', 'indigo blue': '#3a18b1', 'iris': '#6258c4', 'irish green': '#019529', 'ivory': '#ffffcb', 'jade': '#1fa774', 'jade green': '#2baf6a', 'jungle green': '#048243', 'kelley green': '#009337', 'kelly green': '#02ab2e', 'kermit green': '#5cb200', 'key lime': '#aeff6e', 'khaki': '#aaa662', 'khaki green': '#728639', 'kiwi': '#9cef43', 'kiwi green': '#8ee53f', 'lavender': '#c79fef', 'lavender blue': '#8b88f8', 'lavender pink': '#dd85d7', 'lawn green': '#4da409', 
'leaf': '#71aa34', 'leaf green': '#5ca904', 'leafy green': '#51b73b', 'leather': '#ac7434', 'lemon': '#fdff52', 'lemon green': '#adf802', 'lemon lime': '#bffe28', 'lemon yellow': '#fdff38', 'lichen': '#8fb67b', 'light aqua': '#8cffdb', 'light aquamarine': '#7bfdc7', 'light beige': '#fffeb6', 'light blue': '#95d0fc', 'light blue green': '#7efbb3', 'light blue grey': '#b7c9e2', 'light bluish green': '#76fda8', 'light bright green': '#53fe5c', 'light brown': '#ad8150', 'light burgundy': '#a8415b', 'light cyan': '#acfffc', 'light eggplant': '#894585', 'light forest green': '#4f9153', 'light gold': '#fddc5c', 'light grass green': '#9af764', 'light green': '#96f97b', 'light green blue': '#56fca2', 'light greenish blue': '#63f7b4', 'light grey': '#d8dcd6', 'light grey blue': '#9dbcd4', 'light grey green': '#b7e1a1', 'light indigo': '#6d5acf', 'light khaki': '#e6f2a2', 'light lavendar': '#efc0fe', 'light lavender': '#dfc5fe', 'light light blue': '#cafffb', 'light light green': '#c8ffb0', 'light lilac': '#edc8ff', 'light lime': '#aefd6c', 'light lime green': '#b9ff66', 'light magenta': '#fa5ff7', 'light maroon': '#a24857', 'light mauve': '#c292a1', 'light mint': '#b6ffbb', 'light mint green': '#a6fbb2', 'light moss green': '#a6c875', 'light mustard': '#f7d560', 'light navy': '#155084', 'light navy blue': '#2e5a88', 'light neon green': '#4efd54', 'light olive': '#acbf69', 'light olive green': '#a4be5c', 'light orange': '#fdaa48', 'light pastel green': '#b2fba5', 'light pea green': '#c4fe82', 'light peach': '#ffd8b1', 'light periwinkle': '#c1c6fc', 'light pink': '#ffd1df', 'light plum': '#9d5783', 'light purple': '#bf77f6', 'light red': '#ff474c', 'light rose': '#ffc5cb', 'light royal blue': '#3a2efe', 'light sage': '#bcecac', 'light salmon': '#fea993', 'light sea green': '#98f6b0', 'light seafoam': '#a0febf', 'light seafoam green': '#a7ffb5', 'light sky blue': '#c6fcff', 'light tan': '#fbeeac', 'light teal': '#90e4c1', 'light turquoise': '#7ef4cc', 'light urple': '#b36ff6', 
'light violet': '#d6b4fc', 'light yellow': '#fffe7a', 'light yellow green': '#ccfd7f', 'light yellowish green': '#c2ff89', 'lightblue': '#7bc8f6', 'lighter green': '#75fd63', 'lighter purple': '#a55af4', 'lightgreen': '#76ff7b', 'lightish blue': '#3d7afd', 'lightish green': '#61e160', 'lightish purple': '#a552e6', 'lightish red': '#fe2f4a', 'lilac': '#cea2fd', 'liliac': '#c48efd', 'lime': '#aaff32', 'lime green': '#89fe05', 'lime yellow': '#d0fe1d', 'lipstick': '#d5174e', 'lipstick red': '#c0022f', 'macaroni and cheese': '#efb435', 'magenta': '#c20078', 'mahogany': '#4a0100', 'maize': '#f4d054', 'mango': '#ffa62b', 'manilla': '#fffa86', 'marigold': '#fcc006', 'marine': '#042e60', 'marine blue': '#01386a', 'maroon': '#650021', 'mauve': '#ae7181', 'medium blue': '#2c6fbb', 'medium brown': '#7f5112', 'medium green': '#39ad48', 'medium grey': '#7d7f7c', 'medium pink': '#f36196', 'medium purple': '#9e43a2', 'melon': '#ff7855', 'merlot': '#730039', 'metallic blue': '#4f738e', 'mid blue': '#276ab3', 'mid green': '#50a747', 'midnight': '#03012d', 'midnight blue': '#020035', 'midnight purple': '#280137', 'military green': '#667c3e', 'milk chocolate': '#7f4e1e', 'mint': '#9ffeb0', 'mint green': '#8fff9f', 'minty green': '#0bf77d', 'mocha': '#9d7651', 'moss': '#769958', 'moss green': '#658b38', 'mossy green': '#638b27', 'mud': '#735c12', 'mud brown': '#60460f', 'mud green': '#606602', 'muddy brown': '#886806', 'muddy green': '#657432', 'muddy yellow': '#bfac05', 'mulberry': '#920a4e', 'murky green': '#6c7a0e', 'mushroom': '#ba9e88', 'mustard': '#ceb301', 'mustard brown': '#ac7e04', 'mustard green': '#a8b504', 'mustard yellow': '#d2bd0a', 'muted blue': '#3b719f', 'muted green': '#5fa052', 'muted pink': '#d1768f', 'muted purple': '#805b87', 'nasty green': '#70b23f', 'navy': '#01153e', 'navy blue': '#001146', 'navy green': '#35530a', 'neon blue': '#04d9ff', 'neon green': '#0cff0c', 'neon pink': '#fe019a', 'neon purple': '#bc13fe', 'neon red': '#ff073a', 'neon yellow': '#cfff04', 
'nice blue': '#107ab0', 'night blue': '#040348', 'ocean': '#017b92', 'ocean blue': '#03719c', 'ocean green': '#3d9973', 'ocher': '#bf9b0c', 'ochre': '#bf9005', 'ocre': '#c69c04', 'off blue': '#5684ae', 'off green': '#6ba353', 'off white': '#ffffe4', 'off yellow': '#f1f33f', 'old pink': '#c77986', 'old rose': '#c87f89', 'olive': '#6e750e', 'olive brown': '#645403', 'olive drab': '#6f7632', 'olive green': '#677a04', 'olive yellow': '#c2b709', 'orange': '#f97306', 'orange brown': '#be6400', 'orange pink': '#ff6f52', 'orange red': '#fd411e', 'orange yellow': '#ffad01', 'orangeish': '#fd8d49', 'orangered': '#fe420f', 'orangey brown': '#b16002', 'orangey red': '#fa4224', 'orangey yellow': '#fdb915', 'orangish': '#fc824a', 'orangish brown': '#b25f03', 'orangish red': '#f43605', 'orchid': '#c875c4', 'pale': '#fff9d0', 'pale aqua': '#b8ffeb', 'pale blue': '#d0fefe', 'pale brown': '#b1916e', 'pale cyan': '#b7fffa', 'pale gold': '#fdde6c', 'pale green': '#c7fdb5', 'pale grey': '#fdfdfe', 'pale lavender': '#eecffe', 'pale light green': '#b1fc99', 'pale lilac': '#e4cbff', 'pale lime': '#befd73', 'pale lime green': '#b1ff65', 'pale magenta': '#d767ad', 'pale mauve': '#fed0fc', 'pale olive': '#b9cc81', 'pale olive green': '#b1d27b', 'pale orange': '#ffa756', 'pale peach': '#ffe5ad', 'pale pink': '#ffcfdc', 'pale purple': '#b790d4', 'pale red': '#d9544d', 'pale rose': '#fdc1c5', 'pale salmon': '#ffb19a', 'pale sky blue': '#bdf6fe', 'pale teal': '#82cbb2', 'pale turquoise': '#a5fbd5', 'pale violet': '#ceaefa', 'pale yellow': '#ffff84', 'parchment': '#fefcaf', 'pastel blue': '#a2bffe', 'pastel green': '#b0ff9d', 'pastel orange': '#ff964f', 'pastel pink': '#ffbacd', 'pastel purple': '#caa0ff', 'pastel red': '#db5856', 'pastel yellow': '#fffe71', 'pea': '#a4bf20', 'pea green': '#8eab12', 'pea soup': '#929901', 'pea soup green': '#94a617', 'peach': '#ffb07c', 'peachy pink': '#ff9a8a', 'peacock blue': '#016795', 'pear': '#cbf85f', 'periwinkle': '#8e82fe', 'periwinkle blue': '#8f99fb', 
'perrywinkle': '#8f8ce7', 'petrol': '#005f6a', 'pig pink': '#e78ea5', 'pine': '#2b5d34', 'pine green': '#0a481e', 'pink': '#ff81c0', 'pink purple': '#db4bda', 'pink red': '#f5054f', 'pink/purple': '#ef1de7', 'pinkish': '#d46a7e', 'pinkish brown': '#b17261', 'pinkish grey': '#c8aca9', 'pinkish orange': '#ff724c', 'pinkish purple': '#d648d7', 'pinkish red': '#f10c45', 'pinkish tan': '#d99b82', 'pinky': '#fc86aa', 'pinky purple': '#c94cbe', 'pinky red': '#fc2647', 'piss yellow': '#ddd618', 'pistachio': '#c0fa8b', 'plum': '#580f41', 'plum purple': '#4e0550', 'poison green': '#40fd14', 'poo': '#8f7303', 'poo brown': '#885f01', 'poop': '#7f5e00', 'poop brown': '#7a5901', 'poop green': '#6f7c00', 'powder blue': '#b1d1fc', 'powder pink': '#ffb2d0', 'primary blue': '#0804f9', 'prussian blue': '#004577', 'puce': '#a57e52', 'puke': '#a5a502', 'puke brown': '#947706', 'puke green': '#9aae07', 'puke yellow': '#c2be0e', 'pumpkin': '#e17701', 'pumpkin orange': '#fb7d07', 'pure blue': '#0203e2', 'purple': '#7e1e9c', 'purple blue': '#632de9', 'purple brown': '#673a3f', 'purple grey': '#866f85', 'purple pink': '#e03fd8', 'purple red': '#990147', 'purple/blue': '#5d21d0', 'purple/pink': '#d725de', 'purpleish': '#98568d', 'purpleish blue': '#6140ef', 'purpleish pink': '#df4ec8', 'purpley': '#8756e4', 'purpley blue': '#5f34e7', 'purpley grey': '#947e94', 'purpley pink': '#c83cb9', 'purplish': '#94568c', 'purplish blue': '#601ef9', 'purplish brown': '#6b4247', 'purplish grey': '#7a687f', 'purplish pink': '#ce5dae', 'purplish red': '#b0054b', 'purply': '#983fb2', 'purply blue': '#661aee', 'purply pink': '#f075e6', 'putty': '#beae8a', 'racing green': '#014600', 'radioactive green': '#2cfa1f', 'raspberry': '#b00149', 'raw sienna': '#9a6200', 'raw umber': '#a75e09', 'really light blue': '#d4ffff', 'red': '#e50000', 'red brown': '#8b2e16', 'red orange': '#fd3c06', 'red pink': '#fa2a55', 'red purple': '#820747', 'red violet': '#9e0168', 'red wine': '#8c0034', 'reddish': '#c44240', 'reddish 
brown': '#7f2b0a', 'reddish grey': '#997570', 'reddish orange': '#f8481c', 'reddish pink': '#fe2c54', 'reddish purple': '#910951', 'reddy brown': '#6e1005', 'rich blue': '#021bf9', 'rich purple': '#720058', 'robin egg blue': '#8af1fe', "robin's egg": '#6dedfd', "robin's egg blue": '#98eff9', 'rosa': '#fe86a4', 'rose': '#cf6275', 'rose pink': '#f7879a', 'rose red': '#be013c', 'rosy pink': '#f6688e', 'rouge': '#ab1239', 'royal': '#0c1793', 'royal blue': '#0504aa', 'royal purple': '#4b006e', 'ruby': '#ca0147', 'russet': '#a13905', 'rust': '#a83c09', 'rust brown': '#8b3103', 'rust orange': '#c45508', 'rust red': '#aa2704', 'rusty orange': '#cd5909', 'rusty red': '#af2f0d', 'saffron': '#feb209', 'sage': '#87ae73', 'sage green': '#88b378', 'salmon': '#ff796c', 'salmon pink': '#fe7b7c', 'sand': '#e2ca76', 'sand brown': '#cba560', 'sand yellow': '#fce166', 'sandstone': '#c9ae74', 'sandy': '#f1da7a', 'sandy brown': '#c4a661', 'sandy yellow': '#fdee73', 'sap green': '#5c8b15', 'sapphire': '#2138ab', 'scarlet': '#be0119', 'sea': '#3c9992', 'sea blue': '#047495', 'sea green': '#53fca1', 'seafoam': '#80f9ad', 'seafoam blue': '#78d1b6', 'seafoam green': '#7af9ab', 'seaweed': '#18d17b', 'seaweed green': '#35ad6b', 'sepia': '#985e2b', 'shamrock': '#01b44c', 'shamrock green': '#02c14d', 'shit': '#7f5f00', 'shit brown': '#7b5804', 'shit green': '#758000', 'shocking pink': '#fe02a2', 'sick green': '#9db92c', 'sickly green': '#94b21c', 'sickly yellow': '#d0e429', 'sienna': '#a9561e', 'silver': '#c5c9c7', 'sky': '#82cafc', 'sky blue': '#75bbfd', 'slate': '#516572', 'slate blue': '#5b7c99', 'slate green': '#658d6d', 'slate grey': '#59656d', 'slime green': '#99cc04', 'snot': '#acbb0d', 'snot green': '#9dc100', 'soft blue': '#6488ea', 'soft green': '#6fc276', 'soft pink': '#fdb0c0', 'soft purple': '#a66fb5', 'spearmint': '#1ef876', 'spring green': '#a9f971', 'spruce': '#0a5f38', 'squash': '#f2ab15', 'steel': '#738595', 'steel blue': '#5a7d9a', 'steel grey': '#6f828a', 'stone': '#ada587', 
'stormy blue': '#507b9c', 'straw': '#fcf679', 'strawberry': '#fb2943', 'strong blue': '#0c06f7', 'strong pink': '#ff0789', 'sun yellow': '#ffdf22', 'sunflower': '#ffc512', 'sunflower yellow': '#ffda03', 'sunny yellow': '#fff917', 'sunshine yellow': '#fffd37', 'swamp': '#698339', 'swamp green': '#748500', 'tan': '#d1b26f', 'tan brown': '#ab7e4c', 'tan green': '#a9be70', 'tangerine': '#ff9408', 'taupe': '#b9a281', 'tea': '#65ab7c', 'tea green': '#bdf8a3', 'teal': '#029386', 'teal blue': '#01889f', 'teal green': '#25a36f', 'tealish': '#24bca8', 'tealish green': '#0cdc73', 'terra cotta': '#c9643b', 'terracota': '#cb6843', 'terracotta': '#ca6641', 'tiffany blue': '#7bf2da', 'tomato': '#ef4026', 'tomato red': '#ec2d01', 'topaz': '#13bbaf', 'toupe': '#c7ac7d', 'toxic green': '#61de2a', 'tree green': '#2a7e19', 'true blue': '#010fcc', 'true green': '#089404', 'turquoise': '#06c2ac', 'turquoise blue': '#06b1c4', 'turquoise green': '#04f489', 'turtle green': '#75b84f', 'twilight': '#4e518b', 'twilight blue': '#0a437a', 'ugly blue': '#31668a', 'ugly brown': '#7d7103', 'ugly green': '#7a9703', 'ugly pink': '#cd7584', 'ugly purple': '#a442a0', 'ugly yellow': '#d0c101', 'ultramarine': '#2000b1', 'ultramarine blue': '#1805db', 'umber': '#b26400', 'velvet': '#750851', 'vermillion': '#f4320c', 'very dark blue': '#000133', 'very dark brown': '#1d0200', 'very dark green': '#062e03', 'very dark purple': '#2a0134', 'very light blue': '#d5ffff', 'very light brown': '#d3b683', 'very light green': '#d1ffbd', 'very light pink': '#fff4f2', 'very light purple': '#f6cefc', 'very pale blue': '#d6fffe', 'very pale green': '#cffdbc', 'vibrant blue': '#0339f8', 'vibrant green': '#0add08', 'vibrant purple': '#ad03de', 'violet': '#9a0eea', 'violet blue': '#510ac9', 'violet pink': '#fb5ffc', 'violet red': '#a50055', 'viridian': '#1e9167', 'vivid blue': '#152eff', 'vivid green': '#2fef10', 'vivid purple': '#9900fa', 'vomit': '#a2a415', 'vomit green': '#89a203', 'vomit yellow': '#c7c10c', 'warm blue': 
'#4b57db', 'warm brown': '#964e02', 'warm grey': '#978a84', 'warm pink': '#fb5581', 'warm purple': '#952e8f', 'washed out green': '#bcf5a6', 'water blue': '#0e87cc', 'watermelon': '#fd4659', 'weird green': '#3ae57f', 'wheat': '#fbdd7e', 'white': '#ffffff', 'windows blue': '#3778bf', 'wine': '#80013f', 'wine red': '#7b0323', 'wintergreen': '#20f986', 'wisteria': '#a87dc2', 'yellow': '#ffff14', 'yellow brown': '#b79400', 'yellow green': '#c0fb2d', 'yellow ochre': '#cb9d06', 'yellow orange': '#fcb001', 'yellow tan': '#ffe36e', 'yellow/green': '#c8fd3d', 'yellowgreen': '#bbf90f', 'yellowish': '#faee66', 'yellowish brown': '#9b7a01', 'yellowish green': '#b0dd16', 'yellowish orange': '#ffab0f', 'yellowish tan': '#fcfc81', 'yellowy brown': '#ae8b0c', 'yellowy green': '#bff128'} ================================================ FILE: seaborn/distributions.py ================================================ """Plotting functions for visualizing distributions.""" from numbers import Number from functools import partial import math import textwrap import warnings import numpy as np import pandas as pd import matplotlib as mpl import matplotlib.pyplot as plt import matplotlib.transforms as tx from matplotlib.cbook import normalize_kwargs from matplotlib.colors import to_rgba from matplotlib.collections import LineCollection from ._base import VectorPlotter # We have moved univariate histogram computation over to the new Hist class, # but still use the older Histogram for bivariate computation. 
from ._statistics import ECDF, Histogram, KDE from ._stats.counting import Hist from .axisgrid import ( FacetGrid, _facet_docs, ) from .utils import ( remove_na, _get_transform_functions, _kde_support, _check_argument, _assign_default_kwargs, _default_color, ) from .palettes import color_palette from .external import husl from .external.kde import gaussian_kde from ._docstrings import ( DocstringComponents, _core_docs, ) __all__ = ["displot", "histplot", "kdeplot", "ecdfplot", "rugplot", "distplot"] # ==================================================================================== # # Module documentation # ==================================================================================== # _dist_params = dict( multiple=""" multiple : {{"layer", "stack", "fill"}} Method for drawing multiple elements when semantic mapping creates subsets. Only relevant with univariate data. """, log_scale=""" log_scale : bool or number, or pair of bools or numbers Set axis scale(s) to log. A single value sets the data axis for any numeric axes in the plot. A pair of values sets each axis independently. Numeric values are interpreted as the desired base (default 10). When `None` or `False`, seaborn defers to the existing Axes scale. """, legend=""" legend : bool If False, suppress the legend for semantic variables. """, cbar=""" cbar : bool If True, add a colorbar to annotate the color mapping in a bivariate plot. Note: Does not currently support plots with a ``hue`` variable well. """, cbar_ax=""" cbar_ax : :class:`matplotlib.axes.Axes` Pre-existing axes for the colorbar. """, cbar_kws=""" cbar_kws : dict Additional parameters passed to :meth:`matplotlib.figure.Figure.colorbar`. 
""", ) _param_docs = DocstringComponents.from_nested_components( core=_core_docs["params"], facets=DocstringComponents(_facet_docs), dist=DocstringComponents(_dist_params), kde=DocstringComponents.from_function_params(KDE.__init__), hist=DocstringComponents.from_function_params(Histogram.__init__), ecdf=DocstringComponents.from_function_params(ECDF.__init__), ) # ==================================================================================== # # Internal API # ==================================================================================== # class _DistributionPlotter(VectorPlotter): wide_structure = {"x": "@values", "hue": "@columns"} flat_structure = {"x": "@values"} def __init__( self, data=None, variables={}, ): super().__init__(data=data, variables=variables) @property def univariate(self): """Return True if only x or y are used.""" # TODO this could go down to core, but putting it here now. # We'd want to be conceptually clear that univariate only applies # to x/y and not to other semantics, which can exist. # We haven't settled on a good conceptual name for x/y. return bool({"x", "y"} - set(self.variables)) @property def data_variable(self): """Return the variable with data for univariate plots.""" # TODO This could also be in core, but it should have a better name. 
@property
def has_xy_data(self):
    """Return True if at least one of x or y is defined."""
    # TODO this probably belongs on the core plotter rather than here
    positional = {"x", "y"}
    return not positional.isdisjoint(self.variables)
TODO make this better legend_data = dict(zip(labels, handles)) ax_obj.add_legend( legend_data, title=self.variables["hue"], label_order=self.var_levels["hue"], **legend_kws ) def _artist_kws(self, kws, fill, element, multiple, color, alpha): """Handle differences between artists in filled/unfilled plots.""" kws = kws.copy() if fill: kws = normalize_kwargs(kws, mpl.collections.PolyCollection) kws.setdefault("facecolor", to_rgba(color, alpha)) if element == "bars": # Make bar() interface with property cycle correctly # https://github.com/matplotlib/matplotlib/issues/19385 kws["color"] = "none" if multiple in ["stack", "fill"] or element == "bars": kws.setdefault("edgecolor", mpl.rcParams["patch.edgecolor"]) else: kws.setdefault("edgecolor", to_rgba(color, 1)) elif element == "bars": kws["facecolor"] = "none" kws["edgecolor"] = to_rgba(color, alpha) else: kws["color"] = to_rgba(color, alpha) return kws def _quantile_to_level(self, data, quantile): """Return data levels corresponding to quantile cuts of mass.""" isoprop = np.asarray(quantile) values = np.ravel(data) sorted_values = np.sort(values)[::-1] normalized_values = np.cumsum(sorted_values) / values.sum() idx = np.searchsorted(normalized_values, 1 - isoprop) levels = np.take(sorted_values, idx, mode="clip") return levels def _cmap_from_color(self, color): """Return a sequential colormap given a color seed.""" # Like so much else here, this is broadly useful, but keeping it # in this class to signify that I haven't thought overly hard about it... 
r, g, b, _ = to_rgba(color) h, s, _ = husl.rgb_to_husl(r, g, b) xx = np.linspace(-1, 1, int(1.15 * 256))[:256] ramp = np.zeros((256, 3)) ramp[:, 0] = h ramp[:, 1] = s * np.cos(xx) ramp[:, 2] = np.linspace(35, 80, 256) colors = np.clip([husl.husl_to_rgb(*hsl) for hsl in ramp], 0, 1) return mpl.colors.ListedColormap(colors[::-1]) def _default_discrete(self): """Find default values for discrete hist estimation based on variable type.""" if self.univariate: discrete = self.var_types[self.data_variable] == "categorical" else: discrete_x = self.var_types["x"] == "categorical" discrete_y = self.var_types["y"] == "categorical" discrete = discrete_x, discrete_y return discrete def _resolve_multiple(self, curves, multiple): """Modify the density data structure to handle multiple densities.""" # Default baselines have all densities starting at 0 baselines = {k: np.zeros_like(v) for k, v in curves.items()} # TODO we should have some central clearinghouse for checking if any # "grouping" (terminnology?) 
semantics have been assigned if "hue" not in self.variables: return curves, baselines if multiple in ("stack", "fill"): # Setting stack or fill means that the curves share a # support grid / set of bin edges, so we can make a dataframe # Reverse the column order to plot from top to bottom curves = pd.DataFrame(curves).iloc[:, ::-1] # Find column groups that are nested within col/row variables column_groups = {} for i, keyd in enumerate(map(dict, curves.columns)): facet_key = keyd.get("col", None), keyd.get("row", None) column_groups.setdefault(facet_key, []) column_groups[facet_key].append(i) baselines = curves.copy() for col_idxs in column_groups.values(): cols = curves.columns[col_idxs] norm_constant = curves[cols].sum(axis="columns") # Take the cumulative sum to stack curves[cols] = curves[cols].cumsum(axis="columns") # Normalize by row sum to fill if multiple == "fill": curves[cols] = curves[cols].div(norm_constant, axis="index") # Define where each segment starts baselines[cols] = curves[cols].shift(1, axis=1).fillna(0) if multiple == "dodge": # Account for the unique semantic (non-faceting) levels # This will require rethiniking if we add other semantics! 
def _compute_univariate_density(
    self,
    data_variable,
    common_norm,
    common_grid,
    estimate_kws,
    warn_singular=True,
):
    """Estimate a univariate KDE for each hue subset of the data.

    Parameters
    ----------
    data_variable : str
        Name of the column in each subset that holds the observations.
    common_norm : bool
        If True, scale each density by its share of the total weight
        (``part_weight / whole_weight``). Forced to False when no variables
        other than x/y have been assigned.
    common_grid : bool
        If True, and variables beyond x/y exist, define the estimator's
        support once from all non-null observations.
    estimate_kws : dict
        Keyword arguments used to construct the ``KDE`` estimator.
    warn_singular : bool
        If True, emit a ``UserWarning`` for each subset that is skipped
        because its density cannot be estimated.

    Returns
    -------
    densities : dict
        Maps ``tuple(sub_vars.items())`` to a ``pd.Series`` of density
        values indexed by the support points, after the inverse scale
        transform. Skipped (singular) subsets have no entry.

    """
    # Initialize the estimator object
    estimator = KDE(**estimate_kws)

    # Only share the grid / normalization when there is more than one subset
    if set(self.variables) - {"x", "y"}:
        if common_grid:
            all_observations = self.comp_data.dropna()
            estimator.define_support(all_observations[data_variable])
    else:
        common_norm = False

    # Denominator for the common normalization: total weight if weights
    # were provided, otherwise the number of complete rows
    all_data = self.plot_data.dropna()
    if common_norm and "weights" in all_data:
        whole_weight = all_data["weights"].sum()
    else:
        whole_weight = len(all_data)

    densities = {}

    for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

        # Extract the data points from this subset
        observations = sub_data[data_variable]

        # Extract the weights for this subset of observations
        if "weights" in self.variables:
            weights = sub_data["weights"]
            part_weight = weights.sum()
        else:
            weights = None
            part_weight = len(sub_data)

        # Estimate the density of observations at this level
        # (a NaN variance, e.g. from too few points, is treated as 0)
        variance = np.nan_to_num(observations.var())
        singular = len(observations) < 2 or math.isclose(variance, 0)
        try:
            if not singular:
                # Convoluted approach needed because numerical failures
                # can manifest in a few different ways.
                density, support = estimator(observations, weights=weights)
        except np.linalg.LinAlgError:
            singular = True

        if singular:
            msg = (
                "Dataset has 0 variance; skipping density estimate. "
                "Pass `warn_singular=False` to disable this warning."
            )
            if warn_singular:
                warnings.warn(msg, UserWarning, stacklevel=4)
            continue

        # Invert the scaling of the support points
        # NOTE(review): this reads self.data_variable rather than the
        # data_variable argument; the visible callers pass the same value,
        # but confirm the two are meant to be interchangeable.
        _, f_inv = self._get_scale_transforms(self.data_variable)
        support = f_inv(support)

        # Apply a scaling factor so that the integral over all subsets is 1
        if common_norm:
            density *= part_weight / whole_weight

        # Store the density for this level
        key = tuple(sub_vars.items())
        densities[key] = pd.Series(density, index=support)

    return densities
) warnings.warn(msg, UserWarning) estimate_kws["bins"] = 10 # Simplify downstream code if we are not normalizing if estimate_kws["stat"] == "count": common_norm = False orient = self.data_variable # Now initialize the Histogram estimator estimator = Hist(**estimate_kws) histograms = {} # Do pre-compute housekeeping related to multiple groups all_data = self.comp_data.dropna() all_weights = all_data.get("weights", None) multiple_histograms = set(self.variables) - {"x", "y"} if multiple_histograms: if common_bins: bin_kws = estimator._define_bin_params(all_data, orient, None) else: common_norm = False if common_norm and all_weights is not None: whole_weight = all_weights.sum() else: whole_weight = len(all_data) # Estimate the smoothed kernel densities, for use later if kde: # TODO alternatively, clip at min/max bins? kde_kws.setdefault("cut", 0) kde_kws["cumulative"] = estimate_kws["cumulative"] densities = self._compute_univariate_density( self.data_variable, common_norm, common_bins, kde_kws, warn_singular=False, ) # First pass through the data to compute the histograms for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True): # Prepare the relevant data key = tuple(sub_vars.items()) orient = self.data_variable if "weights" in self.variables: sub_data["weight"] = sub_data.pop("weights") part_weight = sub_data["weight"].sum() else: part_weight = len(sub_data) # Do the histogram computation if not (multiple_histograms and common_bins): bin_kws = estimator._define_bin_params(sub_data, orient, None) res = estimator._normalize(estimator._eval(sub_data, orient, bin_kws)) heights = res[estimator.stat].to_numpy() widths = res["space"].to_numpy() edges = res[orient].to_numpy() - widths / 2 # Rescale the smoothed curve to match the histogram if kde and key in densities: density = densities[key] if estimator.cumulative: hist_norm = heights.max() else: hist_norm = (heights * widths).sum() densities[key] *= hist_norm # Convert edges back to original units for 
plotting ax = self._get_axes(sub_vars) _, inv = _get_transform_functions(ax, self.data_variable) widths = inv(edges + widths) - inv(edges) edges = inv(edges) # Pack the histogram data and metadata together edges = edges + (1 - shrink) / 2 * widths widths *= shrink index = pd.MultiIndex.from_arrays([ pd.Index(edges, name="edges"), pd.Index(widths, name="widths"), ]) hist = pd.Series(heights, index=index, name="heights") # Apply scaling to normalize across groups if common_norm: hist *= part_weight / whole_weight # Store the finalized histogram data for future plotting histograms[key] = hist # Modify the histogram and density data to resolve multiple groups histograms, baselines = self._resolve_multiple(histograms, multiple) if kde: densities, _ = self._resolve_multiple( densities, None if multiple == "dodge" else multiple ) # Set autoscaling-related meta sticky_stat = (0, 1) if multiple == "fill" else (0, np.inf) if multiple == "fill": # Filled plots should not have any margins bin_vals = histograms.index.to_frame() edges = bin_vals["edges"] widths = bin_vals["widths"] sticky_data = ( edges.min(), edges.max() + widths.loc[edges.idxmax()] ) else: sticky_data = [] # --- Handle default visual attributes # Note: default linewidth is determined after plotting # Default alpha should depend on other parameters if fill: # Note: will need to account for other grouping semantics if added if "hue" in self.variables and multiple == "layer": default_alpha = .5 if element == "bars" else .25 elif kde: default_alpha = .5 else: default_alpha = .75 else: default_alpha = 1 alpha = plot_kws.pop("alpha", default_alpha) # TODO make parameter? 
hist_artists = [] # Go back through the dataset and draw the plots for sub_vars, _ in self.iter_data("hue", reverse=True): key = tuple(sub_vars.items()) hist = histograms[key].rename("heights").reset_index() bottom = np.asarray(baselines[key]) ax = self._get_axes(sub_vars) # Define the matplotlib attributes that depend on semantic mapping if "hue" in self.variables: sub_color = self._hue_map(sub_vars["hue"]) else: sub_color = color artist_kws = self._artist_kws( plot_kws, fill, element, multiple, sub_color, alpha ) if element == "bars": # Use matplotlib bar plotting plot_func = ax.bar if self.data_variable == "x" else ax.barh artists = plot_func( hist["edges"], hist["heights"] - bottom, hist["widths"], bottom, align="edge", **artist_kws, ) for bar in artists: if self.data_variable == "x": bar.sticky_edges.x[:] = sticky_data bar.sticky_edges.y[:] = sticky_stat else: bar.sticky_edges.x[:] = sticky_stat bar.sticky_edges.y[:] = sticky_data hist_artists.extend(artists) else: # Use either fill_between or plot to draw hull of histogram if element == "step": final = hist.iloc[-1] x = np.append(hist["edges"], final["edges"] + final["widths"]) y = np.append(hist["heights"], final["heights"]) b = np.append(bottom, bottom[-1]) if self.data_variable == "x": step = "post" drawstyle = "steps-post" else: step = "post" # fillbetweenx handles mapping internally drawstyle = "steps-pre" elif element == "poly": x = hist["edges"] + hist["widths"] / 2 y = hist["heights"] b = bottom step = None drawstyle = None if self.data_variable == "x": if fill: artist = ax.fill_between(x, b, y, step=step, **artist_kws) else: artist, = ax.plot(x, y, drawstyle=drawstyle, **artist_kws) artist.sticky_edges.x[:] = sticky_data artist.sticky_edges.y[:] = sticky_stat else: if fill: artist = ax.fill_betweenx(x, b, y, step=step, **artist_kws) else: artist, = ax.plot(y, x, drawstyle=drawstyle, **artist_kws) artist.sticky_edges.x[:] = sticky_stat artist.sticky_edges.y[:] = sticky_data hist_artists.append(artist) 
if kde: # Add in the density curves try: density = densities[key] except KeyError: continue support = density.index if "x" in self.variables: line_args = support, density sticky_x, sticky_y = None, (0, np.inf) else: line_args = density, support sticky_x, sticky_y = (0, np.inf), None line_kws["color"] = to_rgba(sub_color, 1) line, = ax.plot( *line_args, **line_kws, ) if sticky_x is not None: line.sticky_edges.x[:] = sticky_x if sticky_y is not None: line.sticky_edges.y[:] = sticky_y if element == "bars" and "linewidth" not in plot_kws: # Now we handle linewidth, which depends on the scaling of the plot # We will base everything on the minimum bin width hist_metadata = pd.concat([ # Use .items for generality over dict or df h.index.to_frame() for _, h in histograms.items() ]).reset_index(drop=True) thin_bar_idx = hist_metadata["widths"].idxmin() binwidth = hist_metadata.loc[thin_bar_idx, "widths"] left_edge = hist_metadata.loc[thin_bar_idx, "edges"] # Set initial value default_linewidth = math.inf # Loop through subsets based only on facet variables for sub_vars, _ in self.iter_data(): ax = self._get_axes(sub_vars) # Needed in some cases to get valid transforms. # Innocuous in other cases? 
def plot_bivariate_histogram(
    self,
    common_bins, common_norm,
    thresh, pthresh, pmax,
    color, legend,
    cbar, cbar_ax, cbar_kws,
    estimate_kws,
    **plot_kws,
):
    """Draw a 2D histogram of x and y as a ``pcolormesh`` per hue subset.

    Parameters
    ----------
    common_bins : bool
        If True, and variables beyond x/y exist, define the bin parameters
        once from all non-null data so every subset shares them.
    common_norm : bool
        If True, scale each subset's heights by its share of the rows when
        the statistic is not "count". Forced to False when only x/y are
        assigned.
    thresh : number or None
        Cells with heights at or below this value are masked. May be
        overwritten from ``pthresh``.
    pthresh : number or None
        Proportion of total mass converted to a height threshold with
        ``_quantile_to_level``.
    pmax : number or None
        Proportion of total mass converted to the colormap ``vmax`` with
        ``_quantile_to_level``.
    color : color or None
        Seed color for the default colormap; "C0" is used when None.
    legend : bool
        If True and hue is mapped, add a legend of patches.
    cbar : bool
        If True, add a colorbar for each mesh that is drawn.
    cbar_ax : Axes or None
        Passed to ``Figure.colorbar`` as the axes to draw the colorbar in.
    cbar_kws : dict or None
        Extra keyword arguments for ``Figure.colorbar`` (copied, not mutated).
    estimate_kws : dict
        Keyword arguments used to construct the ``Histogram`` estimator.
    **plot_kws
        Remaining keyword arguments passed to ``Axes.pcolormesh``.

    """
    # Default keyword dicts
    cbar_kws = {} if cbar_kws is None else cbar_kws.copy()

    # Now initialize the Histogram estimator
    estimator = Histogram(**estimate_kws)

    # Do pre-compute housekeeping related to multiple groups
    if set(self.variables) - {"x", "y"}:
        all_data = self.comp_data.dropna()
        if common_bins:
            estimator.define_bin_params(
                all_data["x"],
                all_data["y"],
                all_data.get("weights", None),
            )
    else:
        # Without extra variables there is a single group, so there is
        # nothing to normalize across (and all_data is never needed below)
        common_norm = False

    # -- Determine colormap threshold and norm based on the full data

    full_heights = []
    for _, sub_data in self.iter_data(from_comp_data=True):
        sub_heights, _ = estimator(
            sub_data["x"], sub_data["y"], sub_data.get("weights", None)
        )
        full_heights.append(sub_heights)

    # The color scaling is shared when there is a single group or when the
    # groups are normalized together
    common_color_norm = not set(self.variables) - {"x", "y"} or common_norm

    if pthresh is not None and common_color_norm:
        thresh = self._quantile_to_level(full_heights, pthresh)

    plot_kws.setdefault("vmin", 0)
    if common_color_norm:
        if pmax is not None:
            vmax = self._quantile_to_level(full_heights, pmax)
        else:
            # Respect a user-supplied vmax, otherwise use the tallest cell
            vmax = plot_kws.pop("vmax", max(map(np.max, full_heights)))
    else:
        # Resolved per subset inside the loop below
        vmax = None

    # Get a default color
    # (We won't follow the color cycle here, as multiple plots are unlikely)
    if color is None:
        color = "C0"

    # --- Loop over data (subsets) and draw the histograms
    for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

        if sub_data.empty:
            continue

        # Do the histogram computation
        heights, (x_edges, y_edges) = estimator(
            sub_data["x"],
            sub_data["y"],
            weights=sub_data.get("weights", None),
        )

        # Get the axes for this plot
        ax = self._get_axes(sub_vars)

        # Invert the scale for the edges
        _, inv_x = _get_transform_functions(ax, "x")
        _, inv_y = _get_transform_functions(ax, "y")
        x_edges = inv_x(x_edges)
        y_edges = inv_y(y_edges)

        # Apply scaling to normalize across groups
        if estimator.stat != "count" and common_norm:
            heights *= len(sub_data) / len(all_data)

        # Define the specific kwargs for this artist
        artist_kws = plot_kws.copy()
        if "hue" in self.variables:
            # With a hue mapping, the colormap always derives from the
            # level's color (any cmap left in plot_kws is overwritten)
            color = self._hue_map(sub_vars["hue"])
            cmap = self._cmap_from_color(color)
            artist_kws["cmap"] = cmap
        else:
            cmap = artist_kws.pop("cmap", None)
            if isinstance(cmap, str):
                cmap = color_palette(cmap, as_cmap=True)
            elif cmap is None:
                cmap = self._cmap_from_color(color)
            artist_kws["cmap"] = cmap

        # Set the upper norm on the colormap
        if not common_color_norm and pmax is not None:
            vmax = self._quantile_to_level(heights, pmax)
        if vmax is not None:
            artist_kws["vmax"] = vmax

        # Make cells at or below the threshold transparent
        # NOTE(review): this branch tests the truthiness of pthresh, while
        # the shared-norm branch above tests `is not None`; confirm that
        # pthresh=0 is meant to be ignored here.
        if not common_color_norm and pthresh:
            thresh = self._quantile_to_level(heights, pthresh)
        if thresh is not None:
            heights = np.ma.masked_less_equal(heights, thresh)

        # pcolormesh is going to turn the grid off, but we want to keep it
        # I'm not sure if there's a better way to get the grid state
        x_grid = any([l.get_visible() for l in ax.xaxis.get_gridlines()])
        y_grid = any([l.get_visible() for l in ax.yaxis.get_gridlines()])

        mesh = ax.pcolormesh(
            x_edges,
            y_edges,
            heights.T,
            **artist_kws,
        )

        # pcolormesh sets sticky edges, but we only want them if not thresholding
        if thresh is not None:
            mesh.sticky_edges.x[:] = []
            mesh.sticky_edges.y[:] = []

        # Add an optional colorbar
        # Note, we want to improve this. When hue is used, it will stack
        # multiple colorbars with redundant ticks in an ugly way.
        # But it's going to take some work to have multiple colorbars that
        # share ticks nicely.
        if cbar:
            ax.figure.colorbar(mesh, cbar_ax, ax, **cbar_kws)

        # Reset the grid state
        if x_grid:
            ax.grid(True, axis="x")
        if y_grid:
            ax.grid(True, axis="y")

    # --- Finalize the plot

    # Axis labels go on the single Axes, or the first facet when faceting
    ax = self.ax if self.ax is not None else self.facets.axes.flat[0]
    self._add_axis_labels(ax)

    if "hue" in self.variables and legend:

        # TODO if possible, I would like to move the color
        # intensity information into the legend too and label the
        # iso proportions rather than the raw density values

        artist_kws = {}
        artist = partial(mpl.patches.Patch)
        ax_obj = self.ax if self.ax is not None else self.facets
        self._add_legend(
            ax_obj, artist, True, False, "layer", 1, artist_kws, {},
        )
we don't want autoscale margins below the density curve sticky_density = (0, 1) if multiple == "fill" else (0, np.inf) if multiple == "fill": # Filled plots should not have any margins sticky_support = densities.index.min(), densities.index.max() else: sticky_support = [] if fill: if multiple == "layer": default_alpha = .25 else: default_alpha = .75 else: default_alpha = 1 alpha = plot_kws.pop("alpha", default_alpha) # TODO make parameter? # Now iterate through the subsets and draw the densities # We go backwards so stacked densities read from top-to-bottom for sub_vars, _ in self.iter_data("hue", reverse=True): # Extract the support grid and density curve for this level key = tuple(sub_vars.items()) try: density = densities[key] except KeyError: continue support = density.index fill_from = baselines[key] ax = self._get_axes(sub_vars) if "hue" in self.variables: sub_color = self._hue_map(sub_vars["hue"]) else: sub_color = color artist_kws = self._artist_kws( plot_kws, fill, False, multiple, sub_color, alpha ) # Either plot a curve with observation values on the x axis if "x" in self.variables: if fill: artist = ax.fill_between(support, fill_from, density, **artist_kws) else: artist, = ax.plot(support, density, **artist_kws) artist.sticky_edges.x[:] = sticky_support artist.sticky_edges.y[:] = sticky_density # Or plot a curve with observation values on the y axis else: if fill: artist = ax.fill_betweenx(support, fill_from, density, **artist_kws) else: artist, = ax.plot(density, support, **artist_kws) artist.sticky_edges.x[:] = sticky_density artist.sticky_edges.y[:] = sticky_support # --- Finalize the plot ---- ax = self.ax if self.ax is not None else self.facets.axes.flat[0] default_x = default_y = "" if self.data_variable == "x": default_y = "Density" if self.data_variable == "y": default_x = "Density" self._add_axis_labels(ax, default_x, default_y) if "hue" in self.variables and legend: if fill: artist = partial(mpl.patches.Patch) else: artist = 
def plot_bivariate_density(
    self,
    common_norm,
    fill,
    levels,
    thresh,
    color,
    legend,
    cbar,
    warn_singular,
    cbar_ax,
    cbar_kws,
    estimate_kws,
    **contour_kws,
):
    """Draw a 2D kernel density estimate as contours per hue subset.

    Parameters
    ----------
    common_norm : bool
        If True, scale each subset's density by its share of the rows and
        use one shared set of contour levels. Forced to False when only
        x/y are assigned.
    fill : bool
        If True, draw with ``Axes.contourf``; otherwise ``Axes.contour``.
    levels : int or sequence of numbers
        A number requests that many evenly spaced iso-proportions between
        ``thresh`` and 1; a sequence is used as given and must lie within
        [0, 1].
    thresh : number or None
        Lowest iso-proportion when ``levels`` is a number; None means 0.
    color : color
        Single color for the contours when hue is not mapped and no
        ``cmap``/``colors`` is given.
    legend : bool
        If True and hue is mapped, add a legend.
    cbar : bool
        If True, add a colorbar for each contour set that is drawn.
    warn_singular : bool
        If True, emit a ``UserWarning`` for each subset that is skipped
        because its density cannot be estimated.
    cbar_ax : Axes or None
        Passed to ``Figure.colorbar`` as the axes to draw the colorbar in.
    cbar_kws : dict or None
        Extra keyword arguments for ``Figure.colorbar``.
    estimate_kws : dict
        Keyword arguments used to construct the ``KDE`` estimator.
    **contour_kws
        Remaining keyword arguments passed to the contour function.

    Raises
    ------
    ValueError
        If ``levels`` is a sequence with values outside [0, 1].

    """
    contour_kws = contour_kws.copy()

    estimator = KDE(**estimate_kws)

    # A single group has nothing to be normalized against
    if not set(self.variables) - {"x", "y"}:
        common_norm = False

    all_data = self.plot_data.dropna()

    # Loop through the subsets and estimate the KDEs
    # (both dicts are keyed by tuple(sub_vars.items()))
    densities, supports = {}, {}

    for sub_vars, sub_data in self.iter_data("hue", from_comp_data=True):

        # Extract the data points from this sub set
        observations = sub_data[["x", "y"]]
        # A NaN variance (e.g. too few rows) is treated as zero variance
        min_variance = observations.var().fillna(0).min()
        observations = observations["x"], observations["y"]

        # Extract the weights for this subset of observations
        if "weights" in self.variables:
            weights = sub_data["weights"]
        else:
            weights = None

        # Estimate the density of observations at this level
        singular = math.isclose(min_variance, 0)
        try:
            if not singular:
                density, support = estimator(*observations, weights=weights)
        except np.linalg.LinAlgError:
            # Testing for 0 variance doesn't catch all cases where scipy raises,
            # but we can also get a ValueError, so we need this convoluted approach
            # NOTE(review): only LinAlgError is caught here; a ValueError
            # from the estimator would propagate. Confirm that is intended.
            singular = True

        if singular:
            msg = (
                "KDE cannot be estimated (0 variance or perfect covariance). "
                "Pass `warn_singular=False` to disable this warning."
            )
            if warn_singular:
                warnings.warn(msg, UserWarning, stacklevel=3)
            continue

        # Transform the support grid back to the original scale
        ax = self._get_axes(sub_vars)
        _, inv_x = _get_transform_functions(ax, "x")
        _, inv_y = _get_transform_functions(ax, "y")
        support = inv_x(support[0]), inv_y(support[1])

        # Apply a scaling factor so that the integral over all subsets is 1
        if common_norm:
            density *= len(sub_data) / len(all_data)

        key = tuple(sub_vars.items())
        densities[key] = density
        supports[key] = support

    # Define a grid of iso-proportion levels
    if thresh is None:
        thresh = 0
    if isinstance(levels, Number):
        levels = np.linspace(thresh, 1, levels)
    else:
        if min(levels) < 0 or max(levels) > 1:
            raise ValueError("levels must be in [0, 1]")

    # Transform from iso-proportions to iso-densities
    if common_norm:
        # One set of density levels computed from all subsets together
        common_levels = self._quantile_to_level(
            list(densities.values()), levels,
        )
        draw_levels = {k: common_levels for k in densities}
    else:
        # Each subset gets levels computed from its own density
        draw_levels = {
            k: self._quantile_to_level(d, levels)
            for k, d in densities.items()
        }

    # Define the coloring of the contours
    if "hue" in self.variables:
        # The hue mapping determines the colors, so drop (with a warning)
        # any explicit coloring parameters
        for param in ["cmap", "colors"]:
            if param in contour_kws:
                msg = f"{param} parameter ignored when using hue mapping."
                warnings.warn(msg, UserWarning)
                contour_kws.pop(param)
    else:

        # Work out a default coloring of the contours
        coloring_given = set(contour_kws) & {"cmap", "colors"}
        if fill and not coloring_given:
            cmap = self._cmap_from_color(color)
            contour_kws["cmap"] = cmap
        if not fill and not coloring_given:
            contour_kws["colors"] = [color]

        # Use our internal colormap lookup
        cmap = contour_kws.pop("cmap", None)
        if isinstance(cmap, str):
            cmap = color_palette(cmap, as_cmap=True)
        if cmap is not None:
            contour_kws["cmap"] = cmap

    # Loop through the subsets again and plot the data
    for sub_vars, _ in self.iter_data("hue"):

        # Per-level coloring; contour_kws is updated in place on each pass
        if "hue" in sub_vars:
            color = self._hue_map(sub_vars["hue"])
            if fill:
                contour_kws["cmap"] = self._cmap_from_color(color)
            else:
                contour_kws["colors"] = [color]

        ax = self._get_axes(sub_vars)

        # Choose the function to plot with
        # TODO could add a pcolormesh based option as well
        # Which would look something like element="raster"
        if fill:
            contour_func = ax.contourf
        else:
            contour_func = ax.contour

        # Subsets skipped as singular above have no stored density
        key = tuple(sub_vars.items())
        if key not in densities:
            continue
        density = densities[key]
        xx, yy = supports[key]

        # Pop the label kwarg which is unused by contour_func (but warns)
        contour_kws.pop("label", None)

        cset = contour_func(
            xx, yy, density,
            levels=draw_levels[key],
            **contour_kws,
        )

        # Add a color bar representing the contour heights
        # Note: this shows iso densities, not iso proportions
        # See more notes in histplot about how this could be improved
        if cbar:
            cbar_kws = {} if cbar_kws is None else cbar_kws
            ax.figure.colorbar(cset, cbar_ax, ax, **cbar_kws)

    # --- Finalize the plot
    # Axis labels go on the single Axes, or the first facet when faceting
    ax = self.ax if self.ax is not None else self.facets.axes.flat[0]
    self._add_axis_labels(ax)

    if "hue" in self.variables and legend:

        # TODO if possible, I would like to move the contour
        # intensity information into the legend too and label the
        # iso proportions rather than the raw density values

        artist_kws = {}
        if fill:
            artist = partial(mpl.patches.Patch)
        else:
            artist = partial(mpl.lines.Line2D, [], [])
        ax_obj = self.ax if self.ax is not None else self.facets
        self._add_legend(
            ax_obj, artist, fill, False, "layer", 1, artist_kws, {},
        )
not None else self.facets self._add_legend( ax_obj, artist, False, False, None, alpha, plot_kws, {}, ) def plot_rug(self, height, expand_margins, legend, **kws): for sub_vars, sub_data, in self.iter_data(from_comp_data=True): ax = self._get_axes(sub_vars) kws.setdefault("linewidth", 1) if expand_margins: xmarg, ymarg = ax.margins() if "x" in self.variables: ymarg += height * 2 if "y" in self.variables: xmarg += height * 2 ax.margins(x=xmarg, y=ymarg) if "hue" in self.variables: kws.pop("c", None) kws.pop("color", None) if "x" in self.variables: self._plot_single_rug(sub_data, "x", height, ax, kws) if "y" in self.variables: self._plot_single_rug(sub_data, "y", height, ax, kws) # --- Finalize the plot self._add_axis_labels(ax) if "hue" in self.variables and legend: # TODO ideally i'd like the legend artist to look like a rug legend_artist = partial(mpl.lines.Line2D, [], []) self._add_legend( ax, legend_artist, False, False, None, 1, {}, {}, ) def _plot_single_rug(self, sub_data, var, height, ax, kws): """Draw a rugplot along one axis of the plot.""" vector = sub_data[var] n = len(vector) # Return data to linear domain _, inv = _get_transform_functions(ax, var) vector = inv(vector) # We'll always add a single collection with varying colors if "hue" in self.variables: colors = self._hue_map(sub_data["hue"]) else: colors = None # Build the array of values for the LineCollection if var == "x": trans = tx.blended_transform_factory(ax.transData, ax.transAxes) xy_pairs = np.column_stack([ np.repeat(vector, 2), np.tile([0, height], n) ]) if var == "y": trans = tx.blended_transform_factory(ax.transAxes, ax.transData) xy_pairs = np.column_stack([ np.tile([0, height], n), np.repeat(vector, 2) ]) # Draw the lines on the plot line_segs = xy_pairs.reshape([n, 2, 2]) ax.add_collection(LineCollection( line_segs, transform=trans, colors=colors, **kws )) ax.autoscale_view(scalex=var == "x", scaley=var == "y") # 
==================================================================================== # # External API # ==================================================================================== # def histplot( data=None, *, # Vector variables x=None, y=None, hue=None, weights=None, # Histogram computation parameters stat="count", bins="auto", binwidth=None, binrange=None, discrete=None, cumulative=False, common_bins=True, common_norm=True, # Histogram appearance parameters multiple="layer", element="bars", fill=True, shrink=1, # Histogram smoothing with a kernel density estimate kde=False, kde_kws=None, line_kws=None, # Bivariate histogram parameters thresh=0, pthresh=None, pmax=None, cbar=False, cbar_ax=None, cbar_kws=None, # Hue mapping parameters palette=None, hue_order=None, hue_norm=None, color=None, # Axes information log_scale=None, legend=True, ax=None, # Other appearance keywords **kwargs, ): p = _DistributionPlotter( data=data, variables=dict(x=x, y=y, hue=hue, weights=weights), ) p.map_hue(palette=palette, order=hue_order, norm=hue_norm) if ax is None: ax = plt.gca() p._attach(ax, log_scale=log_scale) if p.univariate: # Note, bivariate plots won't cycle if fill: method = ax.bar if element == "bars" else ax.fill_between else: method = ax.plot color = _default_color(method, hue, color, kwargs) if not p.has_xy_data: return ax # Default to discrete bins for categorical variables if discrete is None: discrete = p._default_discrete() estimate_kws = dict( stat=stat, bins=bins, binwidth=binwidth, binrange=binrange, discrete=discrete, cumulative=cumulative, ) if p.univariate: p.plot_univariate_histogram( multiple=multiple, element=element, fill=fill, shrink=shrink, common_norm=common_norm, common_bins=common_bins, kde=kde, kde_kws=kde_kws, color=color, legend=legend, estimate_kws=estimate_kws, line_kws=line_kws, **kwargs, ) else: p.plot_bivariate_histogram( common_bins=common_bins, common_norm=common_norm, thresh=thresh, pthresh=pthresh, pmax=pmax, color=color, 
legend=legend, cbar=cbar, cbar_ax=cbar_ax, cbar_kws=cbar_kws, estimate_kws=estimate_kws, **kwargs, ) return ax histplot.__doc__ = """\ Plot univariate or bivariate histograms to show distributions of datasets. A histogram is a classic visualization tool that represents the distribution of one or more variables by counting the number of observations that fall within discrete bins. This function can normalize the statistic computed within each bin to estimate frequency, density or probability mass, and it can add a smooth curve obtained using a kernel density estimate, similar to :func:`kdeplot`. More information is provided in the :ref:`user guide `. Parameters ---------- {params.core.data} {params.core.xy} {params.core.hue} weights : vector or key in ``data`` If provided, weight the contribution of the corresponding data points towards the count in each bin by these factors. {params.hist.stat} {params.hist.bins} {params.hist.binwidth} {params.hist.binrange} discrete : bool If True, default to ``binwidth=1`` and draw the bars so that they are centered on their corresponding data points. This avoids "gaps" that may otherwise appear when using discrete (integer) data. cumulative : bool If True, plot the cumulative counts as bins increase. common_bins : bool If True, use the same bins when semantic variables produce multiple plots. If using a reference rule to determine the bins, it will be computed with the full dataset. common_norm : bool If True and using a normalized statistic, the normalization will apply over the full dataset. Otherwise, normalize each histogram independently. multiple : {{"layer", "dodge", "stack", "fill"}} Approach to resolving multiple elements when semantic mapping creates subsets. Only relevant with univariate data. element : {{"bars", "step", "poly"}} Visual representation of the histogram statistic. Only relevant with univariate data. fill : bool If True, fill in the space under the histogram. Only relevant with univariate data. 
shrink : number Scale the width of each bar relative to the binwidth by this factor. Only relevant with univariate data. kde : bool If True, compute a kernel density estimate to smooth the distribution and show on the plot as (one or more) line(s). Only relevant with univariate data. kde_kws : dict Parameters that control the KDE computation, as in :func:`kdeplot`. line_kws : dict Parameters that control the KDE visualization, passed to :meth:`matplotlib.axes.Axes.plot`. thresh : number or None Cells with a statistic less than or equal to this value will be transparent. Only relevant with bivariate data. pthresh : number or None Like ``thresh``, but a value in [0, 1] such that cells with aggregate counts (or other statistics, when used) up to this proportion of the total will be transparent. pmax : number or None A value in [0, 1] that sets that saturation point for the colormap at a value such that cells below constitute this proportion of the total count (or other statistic, when used). {params.dist.cbar} {params.dist.cbar_ax} {params.dist.cbar_kws} {params.core.palette} {params.core.hue_order} {params.core.hue_norm} {params.core.color} {params.dist.log_scale} {params.dist.legend} {params.core.ax} kwargs Other keyword arguments are passed to one of the following matplotlib functions: - :meth:`matplotlib.axes.Axes.bar` (univariate, element="bars") - :meth:`matplotlib.axes.Axes.fill_between` (univariate, other element, fill=True) - :meth:`matplotlib.axes.Axes.plot` (univariate, other element, fill=False) - :meth:`matplotlib.axes.Axes.pcolormesh` (bivariate) Returns ------- {returns.ax} See Also -------- {seealso.displot} {seealso.kdeplot} {seealso.rugplot} {seealso.ecdfplot} {seealso.jointplot} Notes ----- The choice of bins for computing and plotting a histogram can exert substantial influence on the insights that one is able to draw from the visualization. If the bins are too large, they may erase important features. 
On the other hand, bins that are too small may be dominated by random variability, obscuring the shape of the true underlying distribution. The default bin size is determined using a reference rule that depends on the sample size and variance. This works well in many cases, (i.e., with "well-behaved" data) but it fails in others. It is always a good to try different bin sizes to be sure that you are not missing something important. This function allows you to specify bins in several different ways, such as by setting the total number of bins to use, the width of each bin, or the specific locations where the bins should break. Examples -------- .. include:: ../docstrings/histplot.rst """.format( params=_param_docs, returns=_core_docs["returns"], seealso=_core_docs["seealso"], ) def kdeplot( data=None, *, x=None, y=None, hue=None, weights=None, palette=None, hue_order=None, hue_norm=None, color=None, fill=None, multiple="layer", common_norm=True, common_grid=False, cumulative=False, bw_method="scott", bw_adjust=1, warn_singular=True, log_scale=None, levels=10, thresh=.05, gridsize=200, cut=3, clip=None, legend=True, cbar=False, cbar_ax=None, cbar_kws=None, ax=None, **kwargs, ): # --- Start with backwards compatability for versions < 0.11.0 ---------------- # Handle (past) deprecation of `data2` if "data2" in kwargs: msg = "`data2` has been removed (replaced by `y`); please update your code." raise TypeError(msg) # Handle deprecation of `vertical` vertical = kwargs.pop("vertical", None) if vertical is not None: if vertical: action_taken = "assigning data to `y`." if x is None: data, y = y, data else: x, y = y, x else: action_taken = "assigning data to `x`." msg = textwrap.dedent(f"""\n The `vertical` parameter is deprecated; {action_taken} This will become an error in seaborn v0.14.0; please update your code. 
""") warnings.warn(msg, UserWarning, stacklevel=2) # Handle deprecation of `bw` bw = kwargs.pop("bw", None) if bw is not None: msg = textwrap.dedent(f"""\n The `bw` parameter is deprecated in favor of `bw_method` and `bw_adjust`. Setting `bw_method={bw}`, but please see the docs for the new parameters and update your code. This will become an error in seaborn v0.14.0. """) warnings.warn(msg, UserWarning, stacklevel=2) bw_method = bw # Handle deprecation of `kernel` if kwargs.pop("kernel", None) is not None: msg = textwrap.dedent("""\n Support for alternate kernels has been removed; using Gaussian kernel. This will become an error in seaborn v0.14.0; please update your code. """) warnings.warn(msg, UserWarning, stacklevel=2) # Handle deprecation of shade_lowest shade_lowest = kwargs.pop("shade_lowest", None) if shade_lowest is not None: if shade_lowest: thresh = 0 msg = textwrap.dedent(f"""\n `shade_lowest` has been replaced by `thresh`; setting `thresh={thresh}. This will become an error in seaborn v0.14.0; please update your code. """) warnings.warn(msg, UserWarning, stacklevel=2) # Handle "soft" deprecation of shade `shade` is not really the right # terminology here, but unlike some of the other deprecated parameters it # is probably very commonly used and much hard to remove. This is therefore # going to be a longer process where, first, `fill` will be introduced and # be used throughout the documentation. In 0.12, when kwarg-only # enforcement hits, we can remove the shade/shade_lowest out of the # function signature all together and pull them out of the kwargs. Then we # can actually fire a FutureWarning, and eventually remove. shade = kwargs.pop("shade", None) if shade is not None: fill = shade msg = textwrap.dedent(f"""\n `shade` is now deprecated in favor of `fill`; setting `fill={shade}`. This will become an error in seaborn v0.14.0; please update your code. 
""") warnings.warn(msg, FutureWarning, stacklevel=2) # Handle `n_levels` # This was never in the formal API but it was processed, and appeared in an # example. We can treat as an alias for `levels` now and deprecate later. levels = kwargs.pop("n_levels", levels) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # p = _DistributionPlotter( data=data, variables=dict(x=x, y=y, hue=hue, weights=weights), ) p.map_hue(palette=palette, order=hue_order, norm=hue_norm) if ax is None: ax = plt.gca() p._attach(ax, allowed_types=["numeric", "datetime"], log_scale=log_scale) method = ax.fill_between if fill else ax.plot color = _default_color(method, hue, color, kwargs) if not p.has_xy_data: return ax # Pack the kwargs for statistics.KDE estimate_kws = dict( bw_method=bw_method, bw_adjust=bw_adjust, gridsize=gridsize, cut=cut, clip=clip, cumulative=cumulative, ) if p.univariate: plot_kws = kwargs.copy() p.plot_univariate_density( multiple=multiple, common_norm=common_norm, common_grid=common_grid, fill=fill, color=color, legend=legend, warn_singular=warn_singular, estimate_kws=estimate_kws, **plot_kws, ) else: p.plot_bivariate_density( common_norm=common_norm, fill=fill, levels=levels, thresh=thresh, legend=legend, color=color, warn_singular=warn_singular, cbar=cbar, cbar_ax=cbar_ax, cbar_kws=cbar_kws, estimate_kws=estimate_kws, **kwargs, ) return ax kdeplot.__doc__ = """\ Plot univariate or bivariate distributions using kernel density estimation. A kernel density estimate (KDE) plot is a method for visualizing the distribution of observations in a dataset, analogous to a histogram. KDE represents the data using a continuous probability density curve in one or more dimensions. The approach is explained further in the :ref:`user guide `. Relative to a histogram, KDE can produce a plot that is less cluttered and more interpretable, especially when drawing multiple distributions. 
But it has the potential to introduce distortions if the underlying distribution is bounded or not smooth. Like a histogram, the quality of the representation also depends on the selection of good smoothing parameters. Parameters ---------- {params.core.data} {params.core.xy} {params.core.hue} weights : vector or key in ``data`` If provided, weight the kernel density estimation using these values. {params.core.palette} {params.core.hue_order} {params.core.hue_norm} {params.core.color} fill : bool or None If True, fill in the area under univariate density curves or between bivariate contours. If None, the default depends on ``multiple``. {params.dist.multiple} common_norm : bool If True, scale each conditional density by the number of observations such that the total area under all densities sums to 1. Otherwise, normalize each density independently. common_grid : bool If True, use the same evaluation grid for each kernel density estimate. Only relevant with univariate data. {params.kde.cumulative} {params.kde.bw_method} {params.kde.bw_adjust} warn_singular : bool If True, issue a warning when trying to estimate the density of data with zero variance. {params.dist.log_scale} levels : int or vector Number of contour levels or values to draw contours at. A vector argument must have increasing values in [0, 1]. Levels correspond to iso-proportions of the density: e.g., 20% of the probability mass will lie below the contour drawn for 0.2. Only relevant with bivariate data. thresh : number in [0, 1] Lowest iso-proportion level at which to draw a contour line. Ignored when ``levels`` is a vector. Only relevant with bivariate data. gridsize : int Number of points on each dimension of the evaluation grid. 
{params.kde.cut} {params.kde.clip} {params.dist.legend} {params.dist.cbar} {params.dist.cbar_ax} {params.dist.cbar_kws} {params.core.ax} kwargs Other keyword arguments are passed to one of the following matplotlib functions: - :meth:`matplotlib.axes.Axes.plot` (univariate, ``fill=False``), - :meth:`matplotlib.axes.Axes.fill_between` (univariate, ``fill=True``), - :meth:`matplotlib.axes.Axes.contour` (bivariate, ``fill=False``), - :meth:`matplotlib.axes.contourf` (bivariate, ``fill=True``). Returns ------- {returns.ax} See Also -------- {seealso.displot} {seealso.histplot} {seealso.ecdfplot} {seealso.jointplot} {seealso.violinplot} Notes ----- The *bandwidth*, or standard deviation of the smoothing kernel, is an important parameter. Misspecification of the bandwidth can produce a distorted representation of the data. Much like the choice of bin width in a histogram, an over-smoothed curve can erase true features of a distribution, while an under-smoothed curve can create false features out of random variability. The rule-of-thumb that sets the default bandwidth works best when the true distribution is smooth, unimodal, and roughly bell-shaped. It is always a good idea to check the default behavior by using ``bw_adjust`` to increase or decrease the amount of smoothing. Because the smoothing algorithm uses a Gaussian kernel, the estimated density curve can extend to values that do not make sense for a particular dataset. For example, the curve may be drawn over negative values when smoothing data that are naturally positive. The ``cut`` and ``clip`` parameters can be used to control the extent of the curve, but datasets that have many observations close to a natural boundary may be better served by a different visualization method. Similar considerations apply when a dataset is naturally discrete or "spiky" (containing many repeated observations of the same value). 
Kernel density estimation will always produce a smooth curve, which would be misleading in these situations. The units on the density axis are a common source of confusion. While kernel density estimation produces a probability distribution, the height of the curve at each point gives a density, not a probability. A probability can be obtained only by integrating the density across a range. The curve is normalized so that the integral over all possible values is 1, meaning that the scale of the density axis depends on the data values. Examples -------- .. include:: ../docstrings/kdeplot.rst """.format( params=_param_docs, returns=_core_docs["returns"], seealso=_core_docs["seealso"], ) def ecdfplot( data=None, *, # Vector variables x=None, y=None, hue=None, weights=None, # Computation parameters stat="proportion", complementary=False, # Hue mapping parameters palette=None, hue_order=None, hue_norm=None, # Axes information log_scale=None, legend=True, ax=None, # Other appearance keywords **kwargs, ): p = _DistributionPlotter( data=data, variables=dict(x=x, y=y, hue=hue, weights=weights), ) p.map_hue(palette=palette, order=hue_order, norm=hue_norm) # We could support other semantics (size, style) here fairly easily # But it would make distplot a bit more complicated. # It's always possible to add features like that later, so I am going to defer. # It will be even easier to wait until after there is a more general/abstract # way to go from semantic specs to artist attributes. 
if ax is None: ax = plt.gca() p._attach(ax, log_scale=log_scale) color = kwargs.pop("color", kwargs.pop("c", None)) kwargs["color"] = _default_color(ax.plot, hue, color, kwargs) if not p.has_xy_data: return ax # We could add this one day, but it's of dubious value if not p.univariate: raise NotImplementedError("Bivariate ECDF plots are not implemented") estimate_kws = dict( stat=stat, complementary=complementary, ) p.plot_univariate_ecdf( estimate_kws=estimate_kws, legend=legend, **kwargs, ) return ax ecdfplot.__doc__ = """\ Plot empirical cumulative distribution functions. An ECDF represents the proportion or count of observations falling below each unique value in a dataset. Compared to a histogram or density plot, it has the advantage that each observation is visualized directly, meaning that there are no binning or smoothing parameters that need to be adjusted. It also aids direct comparisons between multiple distributions. A downside is that the relationship between the appearance of the plot and the basic properties of the distribution (such as its central tendency, variance, and the presence of any bimodality) may not be as intuitive. More information is provided in the :ref:`user guide `. Parameters ---------- {params.core.data} {params.core.xy} {params.core.hue} weights : vector or key in ``data`` If provided, weight the contribution of the corresponding data points towards the cumulative distribution using these values. {params.ecdf.stat} {params.ecdf.complementary} {params.core.palette} {params.core.hue_order} {params.core.hue_norm} {params.dist.log_scale} {params.dist.legend} {params.core.ax} kwargs Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.plot`. Returns ------- {returns.ax} See Also -------- {seealso.displot} {seealso.histplot} {seealso.kdeplot} {seealso.rugplot} Examples -------- .. 
include:: ../docstrings/ecdfplot.rst """.format( params=_param_docs, returns=_core_docs["returns"], seealso=_core_docs["seealso"], ) def rugplot( data=None, *, x=None, y=None, hue=None, height=.025, expand_margins=True, palette=None, hue_order=None, hue_norm=None, legend=True, ax=None, **kwargs ): # A note: I think it would make sense to add multiple= to rugplot and allow # rugs for different hue variables to be shifted orthogonal to the data axis # But is this stacking, or dodging? # A note: if we want to add a style semantic to rugplot, # we could make an option that draws the rug using scatterplot # A note, it would also be nice to offer some kind of histogram/density # rugplot, since alpha blending doesn't work great in the large n regime # --- Start with backwards compatability for versions < 0.11.0 ---------------- a = kwargs.pop("a", None) axis = kwargs.pop("axis", None) if a is not None: data = a msg = textwrap.dedent("""\n The `a` parameter has been replaced; use `x`, `y`, and/or `data` instead. Please update your code; This will become an error in seaborn v0.14.0. """) warnings.warn(msg, UserWarning, stacklevel=2) if axis is not None: if axis == "x": x = data elif axis == "y": y = data data = None msg = textwrap.dedent(f"""\n The `axis` parameter has been deprecated; use the `{axis}` parameter instead. Please update your code; this will become an error in seaborn v0.14.0. """) warnings.warn(msg, UserWarning, stacklevel=2) vertical = kwargs.pop("vertical", None) if vertical is not None: if vertical: action_taken = "assigning data to `y`." if x is None: data, y = y, data else: x, y = y, x else: action_taken = "assigning data to `x`." msg = textwrap.dedent(f"""\n The `vertical` parameter is deprecated; {action_taken} This will become an error in seaborn v0.14.0; please update your code. 
""") warnings.warn(msg, UserWarning, stacklevel=2) # - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - # p = _DistributionPlotter( data=data, variables=dict(x=x, y=y, hue=hue), ) p.map_hue(palette=palette, order=hue_order, norm=hue_norm) if ax is None: ax = plt.gca() p._attach(ax) color = kwargs.pop("color", kwargs.pop("c", None)) kwargs["color"] = _default_color(ax.plot, hue, color, kwargs) if not p.has_xy_data: return ax p.plot_rug(height, expand_margins, legend, **kwargs) return ax rugplot.__doc__ = """\ Plot marginal distributions by drawing ticks along the x and y axes. This function is intended to complement other plots by showing the location of individual observations in an unobtrusive way. Parameters ---------- {params.core.data} {params.core.xy} {params.core.hue} height : float Proportion of axes extent covered by each rug element. Can be negative. expand_margins : bool If True, increase the axes margins by the height of the rug to avoid overlap with other elements. {params.core.palette} {params.core.hue_order} {params.core.hue_norm} legend : bool If False, do not add a legend for semantic variables. {params.core.ax} kwargs Other keyword arguments are passed to :meth:`matplotlib.collections.LineCollection` Returns ------- {returns.ax} Examples -------- .. 
include:: ../docstrings/rugplot.rst """.format( params=_param_docs, returns=_core_docs["returns"], ) def displot( data=None, *, # Vector variables x=None, y=None, hue=None, row=None, col=None, weights=None, # Other plot parameters kind="hist", rug=False, rug_kws=None, log_scale=None, legend=True, # Hue-mapping parameters palette=None, hue_order=None, hue_norm=None, color=None, # Faceting parameters col_wrap=None, row_order=None, col_order=None, height=5, aspect=1, facet_kws=None, **kwargs, ): p = _DistributionPlotter( data=data, variables=dict(x=x, y=y, hue=hue, weights=weights, row=row, col=col), ) p.map_hue(palette=palette, order=hue_order, norm=hue_norm) _check_argument("kind", ["hist", "kde", "ecdf"], kind) # --- Initialize the FacetGrid object # Check for attempt to plot onto specific axes and warn if "ax" in kwargs: msg = ( "`displot` is a figure-level function and does not accept " "the ax= parameter. You may wish to try {}plot.".format(kind) ) warnings.warn(msg, UserWarning) kwargs.pop("ax") for var in ["row", "col"]: # Handle faceting variables that lack name information if var in p.variables and p.variables[var] is None: p.variables[var] = f"_{var}_" # Adapt the plot_data dataframe for use with FacetGrid grid_data = p.plot_data.rename(columns=p.variables) grid_data = grid_data.loc[:, ~grid_data.columns.duplicated()] col_name = p.variables.get("col") row_name = p.variables.get("row") if facet_kws is None: facet_kws = {} g = FacetGrid( data=grid_data, row=row_name, col=col_name, col_wrap=col_wrap, row_order=row_order, col_order=col_order, height=height, aspect=aspect, **facet_kws, ) # Now attach the axes object to the plotter object if kind == "kde": allowed_types = ["numeric", "datetime"] else: allowed_types = None p._attach(g, allowed_types=allowed_types, log_scale=log_scale) # Check for a specification that lacks x/y data and return early if not p.has_xy_data: return g if color is None and hue is None: color = "C0" # XXX else warn if hue is not None? 
kwargs["legend"] = legend # --- Draw the plots if kind == "hist": hist_kws = kwargs.copy() # Extract the parameters that will go directly to Histogram estimate_defaults = {} _assign_default_kwargs(estimate_defaults, Histogram.__init__, histplot) estimate_kws = {} for key, default_val in estimate_defaults.items(): estimate_kws[key] = hist_kws.pop(key, default_val) # Handle derivative defaults if estimate_kws["discrete"] is None: estimate_kws["discrete"] = p._default_discrete() hist_kws["estimate_kws"] = estimate_kws hist_kws.setdefault("color", color) if p.univariate: _assign_default_kwargs(hist_kws, p.plot_univariate_histogram, histplot) p.plot_univariate_histogram(**hist_kws) else: _assign_default_kwargs(hist_kws, p.plot_bivariate_histogram, histplot) p.plot_bivariate_histogram(**hist_kws) elif kind == "kde": kde_kws = kwargs.copy() # Extract the parameters that will go directly to KDE estimate_defaults = {} _assign_default_kwargs(estimate_defaults, KDE.__init__, kdeplot) estimate_kws = {} for key, default_val in estimate_defaults.items(): estimate_kws[key] = kde_kws.pop(key, default_val) kde_kws["estimate_kws"] = estimate_kws kde_kws["color"] = color if p.univariate: _assign_default_kwargs(kde_kws, p.plot_univariate_density, kdeplot) p.plot_univariate_density(**kde_kws) else: _assign_default_kwargs(kde_kws, p.plot_bivariate_density, kdeplot) p.plot_bivariate_density(**kde_kws) elif kind == "ecdf": ecdf_kws = kwargs.copy() # Extract the parameters that will go directly to the estimator estimate_kws = {} estimate_defaults = {} _assign_default_kwargs(estimate_defaults, ECDF.__init__, ecdfplot) for key, default_val in estimate_defaults.items(): estimate_kws[key] = ecdf_kws.pop(key, default_val) ecdf_kws["estimate_kws"] = estimate_kws ecdf_kws["color"] = color if p.univariate: _assign_default_kwargs(ecdf_kws, p.plot_univariate_ecdf, ecdfplot) p.plot_univariate_ecdf(**ecdf_kws) else: raise NotImplementedError("Bivariate ECDF plots are not implemented") # All plot kinds 
can include a rug if rug: # TODO with expand_margins=True, each facet expands margins... annoying! if rug_kws is None: rug_kws = {} _assign_default_kwargs(rug_kws, p.plot_rug, rugplot) rug_kws["legend"] = False if color is not None: rug_kws["color"] = color p.plot_rug(**rug_kws) # Call FacetGrid annotation methods # Note that the legend is currently set inside the plotting method g.set_axis_labels( x_var=p.variables.get("x", g.axes.flat[0].get_xlabel()), y_var=p.variables.get("y", g.axes.flat[0].get_ylabel()), ) g.set_titles() g.tight_layout() if data is not None and (x is not None or y is not None): if not isinstance(data, pd.DataFrame): data = pd.DataFrame(data) g.data = pd.merge( data, g.data[g.data.columns.difference(data.columns)], left_index=True, right_index=True, ) else: wide_cols = { k: f"_{k}_" if v is None else v for k, v in p.variables.items() } g.data = p.plot_data.rename(columns=wide_cols) return g displot.__doc__ = """\ Figure-level interface for drawing distribution plots onto a FacetGrid. This function provides access to several approaches for visualizing the univariate or bivariate distribution of data, including subsets of data defined by semantic mapping and faceting across multiple subplots. The ``kind`` parameter selects the approach to use: - :func:`histplot` (with ``kind="hist"``; the default) - :func:`kdeplot` (with ``kind="kde"``) - :func:`ecdfplot` (with ``kind="ecdf"``; univariate-only) Additionally, a :func:`rugplot` can be added to any kind of plot to show individual observations. Extra keyword arguments are passed to the underlying function, so you should refer to the documentation for each to understand the complete set of options for making plots with this interface. See the :doc:`distribution plots tutorial <../tutorial/distributions>` for a more in-depth discussion of the relative strengths and weaknesses of each approach. 
The distinction between figure-level and axes-level functions is explained further in the :doc:`user guide <../tutorial/function_overview>`. Parameters ---------- {params.core.data} {params.core.xy} {params.core.hue} {params.facets.rowcol} weights : vector or key in ``data`` Observation weights used for computing the distribution function. kind : {{"hist", "kde", "ecdf"}} Approach for visualizing the data. Selects the underlying plotting function and determines the additional set of valid parameters. rug : bool If True, show each observation with marginal ticks (as in :func:`rugplot`). rug_kws : dict Parameters to control the appearance of the rug plot. {params.dist.log_scale} {params.dist.legend} {params.core.palette} {params.core.hue_order} {params.core.hue_norm} {params.core.color} {params.facets.col_wrap} {params.facets.rowcol_order} {params.facets.height} {params.facets.aspect} {params.facets.facet_kws} kwargs Other keyword arguments are documented with the relevant axes-level function: - :func:`histplot` (with ``kind="hist"``) - :func:`kdeplot` (with ``kind="kde"``) - :func:`ecdfplot` (with ``kind="ecdf"``) Returns ------- {returns.facetgrid} See Also -------- {seealso.histplot} {seealso.kdeplot} {seealso.rugplot} {seealso.ecdfplot} {seealso.jointplot} Examples -------- See the API documentation for the axes-level functions for more details about the breadth of options available for each plot kind. .. 
include:: ../docstrings/displot.rst """.format( params=_param_docs, returns=_core_docs["returns"], seealso=_core_docs["seealso"], ) # =========================================================================== # # DEPRECATED FUNCTIONS LIVE BELOW HERE # =========================================================================== # def _freedman_diaconis_bins(a): """Calculate number of hist bins using Freedman-Diaconis rule.""" # From https://stats.stackexchange.com/questions/798/ a = np.asarray(a) if len(a) < 2: return 1 iqr = np.subtract.reduce(np.nanpercentile(a, [75, 25])) h = 2 * iqr / (len(a) ** (1 / 3)) # fall back to sqrt(a) bins if iqr is 0 if h == 0: return int(np.sqrt(a.size)) else: return int(np.ceil((a.max() - a.min()) / h)) def distplot(a=None, bins=None, hist=True, kde=True, rug=False, fit=None, hist_kws=None, kde_kws=None, rug_kws=None, fit_kws=None, color=None, vertical=False, norm_hist=False, axlabel=None, label=None, ax=None, x=None): """ DEPRECATED This function has been deprecated and will be removed in seaborn v0.14.0. It has been replaced by :func:`histplot` and :func:`displot`, two functions with a modern API and many more capabilities. For a guide to updating, please see this notebook: https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751 """ if kde and not hist: axes_level_suggestion = ( "`kdeplot` (an axes-level function for kernel density plots)" ) else: axes_level_suggestion = ( "`histplot` (an axes-level function for histograms)" ) msg = textwrap.dedent(f""" `distplot` is a deprecated function and will be removed in seaborn v0.14.0. Please adapt your code to use either `displot` (a figure-level function with similar flexibility) or {axes_level_suggestion}. 
For a guide to updating your code to use the new functions, please see https://gist.github.com/mwaskom/de44147ed2974457ad6372750bbe5751 """) warnings.warn(msg, UserWarning, stacklevel=2) if ax is None: ax = plt.gca() # Intelligently label the support axis label_ax = bool(axlabel) if axlabel is None and hasattr(a, "name"): axlabel = a.name if axlabel is not None: label_ax = True # Support new-style API if x is not None: a = x # Make a a 1-d float array a = np.asarray(a, float) if a.ndim > 1: a = a.squeeze() # Drop null values from array a = remove_na(a) # Decide if the hist is normed norm_hist = norm_hist or kde or (fit is not None) # Handle dictionary defaults hist_kws = {} if hist_kws is None else hist_kws.copy() kde_kws = {} if kde_kws is None else kde_kws.copy() rug_kws = {} if rug_kws is None else rug_kws.copy() fit_kws = {} if fit_kws is None else fit_kws.copy() # Get the color from the current color cycle if color is None: if vertical: line, = ax.plot(0, a.mean()) else: line, = ax.plot(a.mean(), 0) color = line.get_color() line.remove() # Plug the label into the right kwarg dictionary if label is not None: if hist: hist_kws["label"] = label elif kde: kde_kws["label"] = label elif rug: rug_kws["label"] = label elif fit: fit_kws["label"] = label if hist: if bins is None: bins = min(_freedman_diaconis_bins(a), 50) hist_kws.setdefault("alpha", 0.4) hist_kws.setdefault("density", norm_hist) orientation = "horizontal" if vertical else "vertical" hist_color = hist_kws.pop("color", color) ax.hist(a, bins, orientation=orientation, color=hist_color, **hist_kws) if hist_color != color: hist_kws["color"] = hist_color axis = "y" if vertical else "x" if kde: kde_color = kde_kws.pop("color", color) kdeplot(**{axis: a}, ax=ax, color=kde_color, **kde_kws) if kde_color != color: kde_kws["color"] = kde_color if rug: rug_color = rug_kws.pop("color", color) rugplot(**{axis: a}, ax=ax, color=rug_color, **rug_kws) if rug_color != color: rug_kws["color"] = rug_color if fit is not 
None: def pdf(x): return fit.pdf(x, *params) fit_color = fit_kws.pop("color", "#282828") gridsize = fit_kws.pop("gridsize", 200) cut = fit_kws.pop("cut", 3) clip = fit_kws.pop("clip", (-np.inf, np.inf)) bw = gaussian_kde(a).scotts_factor() * a.std(ddof=1) x = _kde_support(a, bw, gridsize, cut, clip) params = fit.fit(a) y = pdf(x) if vertical: x, y = y, x ax.plot(x, y, color=fit_color, **fit_kws) if fit_color != "#282828": fit_kws["color"] = fit_color if label_ax: if vertical: ax.set_ylabel(axlabel) else: ax.set_xlabel(axlabel) return ax ================================================ FILE: seaborn/external/__init__.py ================================================ ================================================ FILE: seaborn/external/appdirs.py ================================================ #!/usr/bin/env python3 # Copyright (c) 2005-2010 ActiveState Software Inc. # Copyright (c) 2013 Eddy Petrișor # flake8: noqa """ This file is directly from https://github.com/ActiveState/appdirs/blob/3fe6a83776843a46f20c2e5587afcffe05e03b39/appdirs.py The license of https://github.com/ActiveState/appdirs copied below: # This is the MIT license Copyright (c) 2010 ActiveState Software Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
def user_cache_dir(appname=None, appauthor=None, version=None, opinion=True):
    r"""Return the full path of the per-user cache directory for an application.

    Parameters
    ----------
    appname : str, optional
        Name of the application. When falsy, only the platform's base cache
        directory is returned.
    appauthor : str or False, optional
        Only used on Windows: the name of the author or distributing body.
        ``None`` falls back to ``appname``; ``False`` omits the author path
        element altogether.
    version : str, optional
        Extra path element appended after the application name. It is only
        applied when ``appname`` is given.
    opinion : bool, optional
        Only used on Windows: when true (and ``appname`` is given), a final
        ``Cache`` element is appended to the application directory.

    Returns
    -------
    str
        The cache directory path. The directory itself is not created.

    Notes
    -----
    Resulting layouts::

        Mac OS X:  ~/Library/Caches/<AppName>
        Unix:      $XDG_CACHE_HOME/<AppName>, defaulting to ~/.cache/<AppName>
        Windows:   <CSIDL_LOCAL_APPDATA>\<AppAuthor>\<AppName>\Cache

    """
    if system == "win32":
        # A missing author defaults to the application name.
        author = appname if appauthor is None else appauthor
        cache_root = os.path.normpath(_get_win_folder("CSIDL_LOCAL_APPDATA"))
        if appname:
            if author is False:
                cache_root = os.path.join(cache_root, appname)
            else:
                cache_root = os.path.join(cache_root, author, appname)
            if opinion:
                cache_root = os.path.join(cache_root, "Cache")
    else:
        if system == 'darwin':
            cache_root = os.path.expanduser('~/Library/Caches')
        else:
            cache_root = os.getenv(
                'XDG_CACHE_HOME', os.path.expanduser('~/.cache'))
        if appname:
            cache_root = os.path.join(cache_root, appname)

    # The version element only makes sense beneath an application directory.
    if appname and version:
        cache_root = os.path.join(cache_root, version)
    return cache_root
""" import winreg as _winreg shell_folder_name = { "CSIDL_APPDATA": "AppData", "CSIDL_COMMON_APPDATA": "Common AppData", "CSIDL_LOCAL_APPDATA": "Local AppData", }[csidl_name] key = _winreg.OpenKey( _winreg.HKEY_CURRENT_USER, r"Software\Microsoft\Windows\CurrentVersion\Explorer\Shell Folders" ) dir, type = _winreg.QueryValueEx(key, shell_folder_name) return dir def _get_win_folder_with_pywin32(csidl_name): from win32com.shell import shellcon, shell dir = shell.SHGetFolderPath(0, getattr(shellcon, csidl_name), 0, 0) # Try to make this a unicode path because SHGetFolderPath does # not return unicode strings when there is unicode data in the # path. try: dir = unicode(dir) # Downgrade to short path name if have highbit chars. See # . has_high_char = False for c in dir: if ord(c) > 255: has_high_char = True break if has_high_char: try: import win32api dir = win32api.GetShortPathName(dir) except ImportError: pass except UnicodeError: pass return dir def _get_win_folder_with_ctypes(csidl_name): import ctypes csidl_const = { "CSIDL_APPDATA": 26, "CSIDL_COMMON_APPDATA": 35, "CSIDL_LOCAL_APPDATA": 28, }[csidl_name] buf = ctypes.create_unicode_buffer(1024) ctypes.windll.shell32.SHGetFolderPathW(None, csidl_const, None, 0, buf) # Downgrade to short path name if have highbit chars. See # . has_high_char = False for c in buf: if ord(c) > 255: has_high_char = True break if has_high_char: buf2 = ctypes.create_unicode_buffer(1024) if ctypes.windll.kernel32.GetShortPathNameW(buf.value, buf2, 1024): buf = buf2 return buf.value def _get_win_folder_with_jna(csidl_name): import array from com.sun import jna from com.sun.jna.platform import win32 buf_size = win32.WinDef.MAX_PATH * 2 buf = array.zeros('c', buf_size) shell = win32.Shell32.INSTANCE shell.SHGetFolderPath(None, getattr(win32.ShlObj, csidl_name), None, win32.ShlObj.SHGFP_TYPE_CURRENT, buf) dir = jna.Native.toString(buf.tostring()).rstrip("\0") # Downgrade to short path name if have highbit chars. See # . 
# On Windows, bind ``_get_win_folder`` to the first lookup backend whose
# supporting module can be imported. The imports below are probes: they are
# attempted in order and the first one that succeeds decides the backend.
if system == "win32":
    try:
        # First choice: pywin32.
        import win32com.shell
        _get_win_folder = _get_win_folder_with_pywin32
    except ImportError:
        try:
            # Second choice: ctypes with the Windows DLL loader.
            from ctypes import windll
            _get_win_folder = _get_win_folder_with_ctypes
        except ImportError:
            try:
                # Third choice: JNA (the module also recognises Java
                # platforms when it sets ``system``).
                import com.sun.jna
                _get_win_folder = _get_win_folder_with_jna
            except ImportError:
                # Last resort: read the folder from the registry.
                _get_win_folder = _get_win_folder_from_registry
def strip_blank_lines(l):
    """Remove leading and trailing blank lines from a list of lines.

    A line is blank when ``line.strip()`` is empty. The list is modified in
    place and the same list object is returned.
    """
    total = len(l)

    # Index of the first non-blank line (== total when every line is blank).
    first = 0
    while first < total and not l[first].strip():
        first += 1

    # One past the index of the last non-blank line.
    last = total
    while last > first and not l[last - 1].strip():
        last -= 1

    # Trim the tail first so ``first`` stays a valid offset.
    del l[last:]
    del l[:first]
    return l
""" if isinstance(data, list): self._str = data else: self._str = data.split('\n') # store string as list of lines self.reset() def __getitem__(self, n): return self._str[n] def reset(self): self._l = 0 # current line nr def read(self): if not self.eof(): out = self[self._l] self._l += 1 return out else: return '' def seek_next_non_empty_line(self): for l in self[self._l:]: if l.strip(): break else: self._l += 1 def eof(self): return self._l >= len(self._str) def read_to_condition(self, condition_func): start = self._l for line in self[start:]: if condition_func(line): return self[start:self._l] self._l += 1 if self.eof(): return self[start:self._l+1] return [] def read_to_next_empty_line(self): self.seek_next_non_empty_line() def is_empty(line): return not line.strip() return self.read_to_condition(is_empty) def read_to_next_unindented_line(self): def is_unindented(line): return (line.strip() and (len(line.lstrip()) == len(line))) return self.read_to_condition(is_unindented) def peek(self, n=0): if self._l + n < len(self._str): return self[self._l + n] else: return '' def is_empty(self): return not ''.join(self._str).strip() class ParseError(Exception): def __str__(self): message = self.args[0] if hasattr(self, 'docstring'): message = f"{message} in {self.docstring!r}" return message Parameter = namedtuple('Parameter', ['name', 'type', 'desc']) class NumpyDocString(Mapping): """Parses a numpydoc string to an abstract representation Instances define a mapping from section title to structured data. 
""" sections = { 'Signature': '', 'Summary': [''], 'Extended Summary': [], 'Parameters': [], 'Returns': [], 'Yields': [], 'Receives': [], 'Raises': [], 'Warns': [], 'Other Parameters': [], 'Attributes': [], 'Methods': [], 'See Also': [], 'Notes': [], 'Warnings': [], 'References': '', 'Examples': '', 'index': {} } def __init__(self, docstring, config={}): orig_docstring = docstring docstring = textwrap.dedent(docstring).split('\n') self._doc = Reader(docstring) self._parsed_data = copy.deepcopy(self.sections) try: self._parse() except ParseError as e: e.docstring = orig_docstring raise def __getitem__(self, key): return self._parsed_data[key] def __setitem__(self, key, val): if key not in self._parsed_data: self._error_location(f"Unknown section {key}", error=False) else: self._parsed_data[key] = val def __iter__(self): return iter(self._parsed_data) def __len__(self): return len(self._parsed_data) def _is_at_section(self): self._doc.seek_next_non_empty_line() if self._doc.eof(): return False l1 = self._doc.peek().strip() # e.g. Parameters if l1.startswith('.. 
index::'): return True l2 = self._doc.peek(1).strip() # ---------- or ========== return l2.startswith('-'*len(l1)) or l2.startswith('='*len(l1)) def _strip(self, doc): i = 0 j = 0 for i, line in enumerate(doc): if line.strip(): break for j, line in enumerate(doc[::-1]): if line.strip(): break return doc[i:len(doc)-j] def _read_to_next_section(self): section = self._doc.read_to_next_empty_line() while not self._is_at_section() and not self._doc.eof(): if not self._doc.peek(-1).strip(): # previous line was empty section += [''] section += self._doc.read_to_next_empty_line() return section def _read_sections(self): while not self._doc.eof(): data = self._read_to_next_section() name = data[0].strip() if name.startswith('..'): # index section yield name, data[1:] elif len(data) < 2: yield StopIteration else: yield name, self._strip(data[2:]) def _parse_param_list(self, content, single_element_is_type=False): r = Reader(content) params = [] while not r.eof(): header = r.read().strip() if ' : ' in header: arg_name, arg_type = header.split(' : ')[:2] else: if single_element_is_type: arg_name, arg_type = '', header else: arg_name, arg_type = header, '' desc = r.read_to_next_unindented_line() desc = dedent_lines(desc) desc = strip_blank_lines(desc) params.append(Parameter(arg_name, arg_type, desc)) return params # See also supports the following formats. # # # SPACE* COLON SPACE+ SPACE* # ( COMMA SPACE+ )+ (COMMA | PERIOD)? SPACE* # ( COMMA SPACE+ )* SPACE* COLON SPACE+ SPACE* # is one of # # COLON COLON BACKTICK BACKTICK # where # is a legal function name, and # is any nonempty sequence of word characters. # Examples: func_f1 :meth:`func_h1` :obj:`~baz.obj_r` :class:`class_j` # is a string describing the function. 
_role = r":(?P\w+):" _funcbacktick = r"`(?P(?:~\w+\.)?[a-zA-Z0-9_\.-]+)`" _funcplain = r"(?P[a-zA-Z0-9_\.-]+)" _funcname = r"(" + _role + _funcbacktick + r"|" + _funcplain + r")" _funcnamenext = _funcname.replace('role', 'rolenext') _funcnamenext = _funcnamenext.replace('name', 'namenext') _description = r"(?P\s*:(\s+(?P\S+.*))?)?\s*$" _func_rgx = re.compile(r"^\s*" + _funcname + r"\s*") _line_rgx = re.compile( r"^\s*" + r"(?P" + # group for all function names _funcname + r"(?P([,]\s+" + _funcnamenext + r")*)" + r")" + # end of "allfuncs" r"(?P[,\.])?" + # Some function lists have a trailing comma (or period) '\s*' _description) # Empty elements are replaced with '..' empty_description = '..' def _parse_see_also(self, content): """ func_name : Descriptive text continued text another_func_name : Descriptive text func_name1, func_name2, :meth:`func_name`, func_name3 """ items = [] def parse_item_name(text): """Match ':role:`name`' or 'name'.""" m = self._func_rgx.match(text) if not m: raise ParseError(f"{text} is not a item name") role = m.group('role') name = m.group('name') if role else m.group('name2') return name, role, m.end() rest = [] for line in content: if not line.strip(): continue line_match = self._line_rgx.match(line) description = None if line_match: description = line_match.group('desc') if line_match.group('trailing') and description: self._error_location( 'Unexpected comma or period after function list at index %d of ' 'line "%s"' % (line_match.end('trailing'), line), error=False) if not description and line.startswith(' '): rest.append(line.strip()) elif line_match: funcs = [] text = line_match.group('allfuncs') while True: if not text.strip(): break name, role, match_end = parse_item_name(text) funcs.append((name, role)) text = text[match_end:].strip() if text and text[0] == ',': text = text[1:].strip() rest = list(filter(None, [description])) items.append((funcs, rest)) else: raise ParseError(f"{line} is not a item name") return items def 
_parse_index(self, section, content): """ .. index: default :refguide: something, else, and more """ def strip_each_in(lst): return [s.strip() for s in lst] out = {} section = section.split('::') if len(section) > 1: out['default'] = strip_each_in(section[1].split(','))[0] for line in content: line = line.split(':') if len(line) > 2: out[line[1]] = strip_each_in(line[2].split(',')) return out def _parse_summary(self): """Grab signature (if given) and summary""" if self._is_at_section(): return # If several signatures present, take the last one while True: summary = self._doc.read_to_next_empty_line() summary_str = " ".join([s.strip() for s in summary]).strip() compiled = re.compile(r'^([\w., ]+=)?\s*[\w\.]+\(.*\)$') if compiled.match(summary_str): self['Signature'] = summary_str if not self._is_at_section(): continue break if summary is not None: self['Summary'] = summary if not self._is_at_section(): self['Extended Summary'] = self._read_to_next_section() def _parse(self): self._doc.reset() self._parse_summary() sections = list(self._read_sections()) section_names = {section for section, content in sections} has_returns = 'Returns' in section_names has_yields = 'Yields' in section_names # We could do more tests, but we are not. Arbitrarily. if has_returns and has_yields: msg = 'Docstring contains both a Returns and Yields section.' raise ValueError(msg) if not has_yields and 'Receives' in section_names: msg = 'Docstring contains a Receives section but not Yields.' 
raise ValueError(msg) for (section, content) in sections: if not section.startswith('..'): section = (s.capitalize() for s in section.split(' ')) section = ' '.join(section) if self.get(section): self._error_location(f"The section {section} appears twice") if section in ('Parameters', 'Other Parameters', 'Attributes', 'Methods'): self[section] = self._parse_param_list(content) elif section in ('Returns', 'Yields', 'Raises', 'Warns', 'Receives'): self[section] = self._parse_param_list( content, single_element_is_type=True) elif section.startswith('.. index::'): self['index'] = self._parse_index(section, content) elif section == 'See Also': self['See Also'] = self._parse_see_also(content) else: self[section] = content def _error_location(self, msg, error=True): if hasattr(self, '_obj'): # we know where the docs came from: try: filename = inspect.getsourcefile(self._obj) except TypeError: filename = None msg = msg + f" in the docstring of {self._obj} in {filename}." if error: raise ValueError(msg) else: warn(msg) # string conversion routines def _str_header(self, name, symbol='-'): return [name, len(name)*symbol] def _str_indent(self, doc, indent=4): out = [] for line in doc: out += [' '*indent + line] return out def _str_signature(self): if self['Signature']: return [self['Signature'].replace('*', r'\*')] + [''] else: return [''] def _str_summary(self): if self['Summary']: return self['Summary'] + [''] else: return [] def _str_extended_summary(self): if self['Extended Summary']: return self['Extended Summary'] + [''] else: return [] def _str_param_list(self, name): out = [] if self[name]: out += self._str_header(name) for param in self[name]: parts = [] if param.name: parts.append(param.name) if param.type: parts.append(param.type) out += [' : '.join(parts)] if param.desc and ''.join(param.desc).strip(): out += self._str_indent(param.desc) out += [''] return out def _str_section(self, name): out = [] if self[name]: out += self._str_header(name) out += self[name] out 
+= [''] return out def _str_see_also(self, func_role): if not self['See Also']: return [] out = [] out += self._str_header("See Also") out += [''] last_had_desc = True for funcs, desc in self['See Also']: assert isinstance(funcs, list) links = [] for func, role in funcs: if role: link = f':{role}:`{func}`' elif func_role: link = f':{func_role}:`{func}`' else: link = f"`{func}`_" links.append(link) link = ', '.join(links) out += [link] if desc: out += self._str_indent([' '.join(desc)]) last_had_desc = True else: last_had_desc = False out += self._str_indent([self.empty_description]) if last_had_desc: out += [''] out += [''] return out def _str_index(self): idx = self['index'] out = [] output_index = False default_index = idx.get('default', '') if default_index: output_index = True out += [f'.. index:: {default_index}'] for section, references in idx.items(): if section == 'default': continue output_index = True out += [f" :{section}: {', '.join(references)}"] if output_index: return out else: return '' def __str__(self, func_role=''): out = [] out += self._str_signature() out += self._str_summary() out += self._str_extended_summary() for param_list in ('Parameters', 'Returns', 'Yields', 'Receives', 'Other Parameters', 'Raises', 'Warns'): out += self._str_param_list(param_list) out += self._str_section('Warnings') out += self._str_see_also(func_role) for s in ('Notes', 'References', 'Examples'): out += self._str_section(s) for param_list in ('Attributes', 'Methods'): out += self._str_param_list(param_list) out += self._str_index() return '\n'.join(out) def indent(str, indent=4): indent_str = ' '*indent if str is None: return indent_str lines = str.split('\n') return '\n'.join(indent_str + l for l in lines) def dedent_lines(lines): """Deindent a list of lines maximally""" return textwrap.dedent("\n".join(lines)).split("\n") def header(text, style='-'): return text + '\n' + style*len(text) + '\n' class FunctionDoc(NumpyDocString): def __init__(self, func, role='func', 
class ClassDoc(NumpyDocString):
    """Parsed numpydoc representation of a class docstring.

    After the docstring is parsed, empty ``Methods`` and ``Attributes``
    sections are filled in from the members found on the class, subject to
    the ``members``, ``exclude-members``, ``show_class_members`` and
    ``show_inherited_class_members`` entries of ``config``.
    """

    # Dunder names that count as public methods despite the leading underscore.
    extra_public_methods = ['__call__']

    def __init__(self, cls, doc=None, modulename='', func_doc=FunctionDoc,
                 config={}):
        if cls is not None and not inspect.isclass(cls):
            raise ValueError(f"Expected a class or None, but got {cls!r}")
        self._cls = cls

        # Use sphinx's ALL sentinel when sphinx is loaded; otherwise a
        # private object that no config value can be identical to.
        if 'sphinx' in sys.modules:
            from sphinx.ext.autodoc import ALL
        else:
            ALL = object()

        self.show_inherited_members = config.get(
            'show_inherited_class_members', True)

        if modulename and not modulename.endswith('.'):
            modulename += '.'
        self._mod = modulename

        if doc is None:
            if cls is None:
                raise ValueError("No class or documentation string given")
            doc = pydoc.getdoc(cls)

        NumpyDocString.__init__(self, doc)

        wanted = config.get('members', [])
        if wanted is ALL:
            wanted = None
        excluded = config.get('exclude-members', [])

        if not config.get('show_class_members', True) or excluded is ALL:
            return

        # Both name lists are collected before either section is written.
        discovered = [('Methods', self.methods),
                      ('Attributes', self.properties)]
        for field, member_names in discovered:
            if self[field]:
                # The docstring documents this section itself; keep it.
                continue
            entries = []
            for name in sorted(member_names):
                if name in excluded or (wanted and name not in wanted):
                    continue
                try:
                    member_doc = pydoc.getdoc(getattr(self._cls, name))
                    desc = member_doc.splitlines() if member_doc else []
                    entries.append(Parameter(name, '', desc))
                except AttributeError:
                    pass  # method doesn't exist
            self[field] = entries

    @property
    def methods(self):
        """Names of the callable members that should be documented."""
        if self._cls is None:
            return []
        found = []
        for name, member in inspect.getmembers(self._cls):
            if name.startswith('_') and name not in self.extra_public_methods:
                continue
            if not isinstance(member, Callable):
                continue
            if self._is_show_member(name):
                found.append(name)
        return found

    @property
    def properties(self):
        """Names of the public property-like members to document."""
        if self._cls is None:
            return []
        found = []
        for name, member in inspect.getmembers(self._cls):
            if name.startswith('_'):
                continue
            attribute_like = (member is None
                              or isinstance(member, property)
                              or inspect.isdatadescriptor(member))
            if attribute_like and self._is_show_member(name):
                found.append(name)
        return found

    def _is_show_member(self, name):
        """Return True when the member ``name`` should be listed."""
        if self.show_inherited_members:
            return True  # show all class members
        # Otherwise only members defined directly on the class are shown.
        return name in self._cls.__dict__
def max_chroma(L, H):
    """Return the smallest positive chroma bound for lightness L and hue H.

    ``H`` is given in degrees. For each row of the module-level matrix ``m``
    and each limit ``t`` in (0.0, 1.0), a candidate chroma is computed and
    the smallest strictly positive one is returned (``inf`` if none is
    positive).
    """
    angle = math.radians(H)
    sinH = math.sin(angle)
    cosH = math.cos(angle)

    cubic = math.pow(L + 16, 3.0) / 1560896.0
    if cubic > 0.008856:
        scale = cubic
    else:
        scale = L / 903.3

    smallest = float("inf")
    for m1, m2, m3 in m:
        top = (0.99915 * m1 + 1.05122 * m2 + 1.14460 * m3) * scale
        rbottom = 0.86330 * m3 - 0.17266 * m2
        lbottom = 0.12949 * m3 - 0.38848 * m1
        bottom = (rbottom * sinH + lbottom * cosH) * scale

        for t in (0.0, 1.0):
            candidate = L * (top - 1.05122 * t) / (bottom + 0.17266 * sinH * t)
            if 0.0 < candidate < smallest:
                smallest = candidate
    return smallest
def rgb_prepare(triple):
    """Convert RGB channel values in [0, 1] to integers in [0, 255].

    Each channel is rounded to three decimals, clamped to [0, 1] and scaled
    by 255.

    Parameters
    ----------
    triple : iterable of float
        Channel values; one output integer is produced per element.

    Returns
    -------
    list of int

    Raises
    ------
    ValueError
        If a rounded channel lies outside the tolerance band
        [-0.0001, 1.0001]. (Previously a bare ``Exception``; ``ValueError``
        is still caught by ``except Exception``.)
    """
    ret = []
    for ch in triple:
        # Round first so tiny floating-point overshoots are tolerated.
        ch = round(ch, 3)

        if ch < -0.0001 or ch > 1.0001:
            raise ValueError(f"Illegal RGB value {ch:f}")

        if ch < 0:
            ch = 0
        if ch > 1:
            ch = 1

        # Fix for Python 3 which by default rounds 4.5 down to 4.0
        # instead of Python 2 which is rounded to 5.0 which caused
        # a couple off by one errors in the tests. Tests now all pass
        # in Python 2 and Python 3
        ret.append(int(round(ch * 255 + 0.001, 0)))

    return ret
def luv_to_lch(triple):
    """Convert an ``[L, U, V]`` triple to ``[L, C, H]``.

    ``C`` is the Euclidean length of ``(U, V)`` and ``H`` is the angle of
    ``(U, V)`` in degrees, shifted by 360 when negative. ``L`` is passed
    through unchanged.
    """
    lightness, u, v = triple

    squared_length = math.pow(u, 2) + math.pow(v, 2)
    chroma = math.pow(squared_length, (1.0 / 2.0))

    hue = math.degrees(math.atan2(v, u))
    if hue < 0.0:
        hue = 360.0 + hue

    return [lightness, chroma, hue]
2003-2019, SciPy Developers. All rights reserved. Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. """ # ------------------------------------------------------------------------------- # # Define classes for (uni/multi)-variate kernel density estimation. # # Currently, only Gaussian kernels are implemented. # # Written by: Robert Kern # # Date: 2004-08-09 # # Modified: 2005-02-10 by Robert Kern. # Contributed to SciPy # 2005-10-07 by Robert Kern. # Some fixes to match the new scipy_core # # Copyright 2004-2005 by Enthought, Inc. 
# # ------------------------------------------------------------------------------- import numpy as np from numpy import (asarray, atleast_2d, reshape, zeros, newaxis, dot, exp, pi, sqrt, power, atleast_1d, sum, ones, cov) from numpy import linalg __all__ = ['gaussian_kde'] class gaussian_kde: """Representation of a kernel-density estimate using Gaussian kernels. Kernel density estimation is a way to estimate the probability density function (PDF) of a random variable in a non-parametric way. `gaussian_kde` works for both uni-variate and multi-variate data. It includes automatic bandwidth determination. The estimation works best for a unimodal distribution; bimodal or multi-modal distributions tend to be oversmoothed. Parameters ---------- dataset : array_like Datapoints to estimate from. In case of univariate data this is a 1-D array, otherwise a 2-D array with shape (# of dims, # of data). bw_method : str, scalar or callable, optional The method used to calculate the estimator bandwidth. This can be 'scott', 'silverman', a scalar constant or a callable. If a scalar, this will be used directly as `kde.factor`. If a callable, it should take a `gaussian_kde` instance as only parameter and return a scalar. If None (default), 'scott' is used. See Notes for more details. weights : array_like, optional weights of datapoints. This must be the same shape as dataset. If None (default), the samples are assumed to be equally weighted Attributes ---------- dataset : ndarray The dataset with which `gaussian_kde` was initialized. d : int Number of dimensions. n : int Number of datapoints. neff : int Effective number of datapoints. .. versionadded:: 1.2.0 factor : float The bandwidth factor, obtained from `kde.covariance_factor`, with which the covariance matrix is multiplied. covariance : ndarray The covariance matrix of `dataset`, scaled by the calculated bandwidth (`kde.factor`). inv_cov : ndarray The inverse of `covariance`. 
Methods ------- evaluate __call__ integrate_gaussian integrate_box_1d integrate_box integrate_kde pdf logpdf resample set_bandwidth covariance_factor Notes ----- Bandwidth selection strongly influences the estimate obtained from the KDE (much more so than the actual shape of the kernel). Bandwidth selection can be done by a "rule of thumb", by cross-validation, by "plug-in methods" or by other means; see [3]_, [4]_ for reviews. `gaussian_kde` uses a rule of thumb, the default is Scott's Rule. Scott's Rule [1]_, implemented as `scotts_factor`, is:: n**(-1./(d+4)), with ``n`` the number of data points and ``d`` the number of dimensions. In the case of unequally weighted points, `scotts_factor` becomes:: neff**(-1./(d+4)), with ``neff`` the effective number of datapoints. Silverman's Rule [2]_, implemented as `silverman_factor`, is:: (n * (d + 2) / 4.)**(-1. / (d + 4)). or in the case of unequally weighted points:: (neff * (d + 2) / 4.)**(-1. / (d + 4)). Good general descriptions of kernel density estimation can be found in [1]_ and [2]_, the mathematics for this multi-dimensional implementation can be found in [1]_. With a set of weighted samples, the effective number of datapoints ``neff`` is defined by:: neff = sum(weights)^2 / sum(weights^2) as detailed in [5]_. References ---------- .. [1] D.W. Scott, "Multivariate Density Estimation: Theory, Practice, and Visualization", John Wiley & Sons, New York, Chicester, 1992. .. [2] B.W. Silverman, "Density Estimation for Statistics and Data Analysis", Vol. 26, Monographs on Statistics and Applied Probability, Chapman and Hall, London, 1986. .. [3] B.A. Turlach, "Bandwidth Selection in Kernel Density Estimation: A Review", CORE and Institut de Statistique, Vol. 19, pp. 1-33, 1993. .. [4] D.M. Bashtannyk and R.J. Hyndman, "Bandwidth selection for kernel conditional density estimation", Computational Statistics & Data Analysis, Vol. 36, pp. 279-298, 2001. .. [5] Gray P. G., 1969, Journal of the Royal Statistical Society. 
Series A (General), 132, 272 """ def __init__(self, dataset, bw_method=None, weights=None): self.dataset = atleast_2d(asarray(dataset)) if not self.dataset.size > 1: raise ValueError("`dataset` input should have multiple elements.") self.d, self.n = self.dataset.shape if weights is not None: self._weights = atleast_1d(weights).astype(float) self._weights /= sum(self._weights) if self.weights.ndim != 1: raise ValueError("`weights` input should be one-dimensional.") if len(self._weights) != self.n: raise ValueError("`weights` input should be of length n") self._neff = 1/sum(self._weights**2) self.set_bandwidth(bw_method=bw_method) def evaluate(self, points): """Evaluate the estimated pdf on a set of points. Parameters ---------- points : (# of dimensions, # of points)-array Alternatively, a (# of dimensions,) vector can be passed in and treated as a single point. Returns ------- values : (# of points,)-array The values at each point. Raises ------ ValueError : if the dimensionality of the input points is different than the dimensionality of the KDE. 
""" points = atleast_2d(asarray(points)) d, m = points.shape if d != self.d: if d == 1 and m == self.d: # points was passed in as a row vector points = reshape(points, (self.d, 1)) m = 1 else: msg = f"points have dimension {d}, dataset has dimension {self.d}" raise ValueError(msg) output_dtype = np.common_type(self.covariance, points) result = zeros((m,), dtype=output_dtype) whitening = linalg.cholesky(self.inv_cov) scaled_dataset = dot(whitening, self.dataset) scaled_points = dot(whitening, points) if m >= self.n: # there are more points than data, so loop over data for i in range(self.n): diff = scaled_dataset[:, i, newaxis] - scaled_points energy = sum(diff * diff, axis=0) / 2.0 result += self.weights[i]*exp(-energy) else: # loop over points for i in range(m): diff = scaled_dataset - scaled_points[:, i, newaxis] energy = sum(diff * diff, axis=0) / 2.0 result[i] = sum(exp(-energy)*self.weights, axis=0) result = result / self._norm_factor return result __call__ = evaluate def scotts_factor(self): """Compute Scott's factor. Returns ------- s : float Scott's factor. """ return power(self.neff, -1./(self.d+4)) def silverman_factor(self): """Compute the Silverman factor. Returns ------- s : float The silverman factor. """ return power(self.neff*(self.d+2.0)/4.0, -1./(self.d+4)) # Default method to calculate bandwidth, can be overwritten by subclass covariance_factor = scotts_factor covariance_factor.__doc__ = """Computes the coefficient (`kde.factor`) that multiplies the data covariance matrix to obtain the kernel covariance matrix. The default is `scotts_factor`. A subclass can overwrite this method to provide a different method, or set it through a call to `kde.set_bandwidth`.""" def set_bandwidth(self, bw_method=None): """Compute the estimator bandwidth with given method. The new bandwidth calculated after a call to `set_bandwidth` is used for subsequent evaluations of the estimated density. 
Parameters ---------- bw_method : str, scalar or callable, optional The method used to calculate the estimator bandwidth. This can be 'scott', 'silverman', a scalar constant or a callable. If a scalar, this will be used directly as `kde.factor`. If a callable, it should take a `gaussian_kde` instance as only parameter and return a scalar. If None (default), nothing happens; the current `kde.covariance_factor` method is kept. Notes ----- .. versionadded:: 0.11 """ if bw_method is None: pass elif bw_method == 'scott': self.covariance_factor = self.scotts_factor elif bw_method == 'silverman': self.covariance_factor = self.silverman_factor elif np.isscalar(bw_method) and not isinstance(bw_method, str): self._bw_method = 'use constant' self.covariance_factor = lambda: bw_method elif callable(bw_method): self._bw_method = bw_method self.covariance_factor = lambda: self._bw_method(self) else: msg = "`bw_method` should be 'scott', 'silverman', a scalar " \ "or a callable." raise ValueError(msg) self._compute_covariance() def _compute_covariance(self): """Computes the covariance matrix for each Gaussian kernel using covariance_factor(). """ self.factor = self.covariance_factor() # Cache covariance and inverse covariance of the data if not hasattr(self, '_data_inv_cov'): self._data_covariance = atleast_2d(cov(self.dataset, rowvar=1, bias=False, aweights=self.weights)) self._data_inv_cov = linalg.inv(self._data_covariance) self.covariance = self._data_covariance * self.factor**2 self.inv_cov = self._data_inv_cov / self.factor**2 self._norm_factor = sqrt(linalg.det(2*pi*self.covariance)) def pdf(self, x): """ Evaluate the estimated pdf on a provided set of points. Notes ----- This is an alias for `gaussian_kde.evaluate`. See the ``evaluate`` docstring for more details. 
""" return self.evaluate(x) @property def weights(self): try: return self._weights except AttributeError: self._weights = ones(self.n)/self.n return self._weights @property def neff(self): try: return self._neff except AttributeError: self._neff = 1/sum(self.weights**2) return self._neff ================================================ FILE: seaborn/external/version.py ================================================ """Extract reference documentation from the pypa/packaging source tree. In the process of copying, some unused methods / classes were removed. These include: - parse() - anything involving LegacyVersion This software is made available under the terms of *either* of the licenses found in LICENSE.APACHE or LICENSE.BSD. Contributions to this software is made under the terms of *both* these licenses. Vendored from: - https://github.com/pypa/packaging/ - commit ba07d8287b4554754ac7178d177033ea3f75d489 (09/09/2021) """ # This file is dual licensed under the terms of the Apache License, Version # 2.0, and the BSD License. See the LICENSE file in the root of this repository # for complete details. 
import collections
import itertools
import re
from typing import Callable, Optional, SupportsInt, Tuple, Union

__all__ = ["Version", "InvalidVersion", "VERSION_PATTERN"]


# Vendored from https://github.com/pypa/packaging/blob/main/packaging/_structures.py


class InfinityType:
    """Sentinel that compares greater than every object except itself."""

    def __repr__(self) -> str:
        return "Infinity"

    def __hash__(self) -> int:
        return hash(repr(self))

    def __lt__(self, other: object) -> bool:
        return False

    def __le__(self, other: object) -> bool:
        return False

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__)

    def __ne__(self, other: object) -> bool:
        return not isinstance(other, self.__class__)

    def __gt__(self, other: object) -> bool:
        return True

    def __ge__(self, other: object) -> bool:
        return True

    def __neg__(self: object) -> "NegativeInfinityType":
        # Negating the upper sentinel yields the shared lower sentinel.
        return NegativeInfinity


Infinity = InfinityType()


class NegativeInfinityType:
    """Sentinel that compares less than every object except itself."""

    def __repr__(self) -> str:
        return "-Infinity"

    def __hash__(self) -> int:
        return hash(repr(self))

    def __lt__(self, other: object) -> bool:
        return True

    def __le__(self, other: object) -> bool:
        return True

    def __eq__(self, other: object) -> bool:
        return isinstance(other, self.__class__)

    def __ne__(self, other: object) -> bool:
        return not isinstance(other, self.__class__)

    def __gt__(self, other: object) -> bool:
        return False

    def __ge__(self, other: object) -> bool:
        return False

    def __neg__(self: object) -> InfinityType:
        # Negating the lower sentinel yields the shared upper sentinel.
        return Infinity


NegativeInfinity = NegativeInfinityType()


# Vendored from https://github.com/pypa/packaging/blob/main/packaging/version.py

InfiniteTypes = Union[InfinityType, NegativeInfinityType]
PrePostDevType = Union[InfiniteTypes, Tuple[str, int]]
SubLocalType = Union[InfiniteTypes, int, str]
LocalType = Union[
    NegativeInfinityType,
    Tuple[
        Union[
            SubLocalType,
            Tuple[SubLocalType, str],
            Tuple[NegativeInfinityType, SubLocalType],
        ],
        ...,
    ],
]
CmpKey = Tuple[
    int, Tuple[int, ...], PrePostDevType, PrePostDevType, PrePostDevType, LocalType
]
LegacyCmpKey = Tuple[int, Tuple[str, ...]]
VersionComparisonMethod = Callable[
    [Union[CmpKey, LegacyCmpKey], Union[CmpKey, LegacyCmpKey]], bool
]

_Version = collections.namedtuple(
    "_Version", ["epoch", "release", "dev", "pre", "post", "local"]
)


class InvalidVersion(ValueError):
    """
    An invalid version was found, users should refer to PEP 440.
    """


class _BaseVersion:
    """Comparison and hashing mixin driven by a precomputed ``_key`` tuple."""

    _key: Union[CmpKey, LegacyCmpKey]

    def __hash__(self) -> int:
        return hash(self._key)

    # Please keep the duplicated `isinstance` check
    # in the six comparisons hereunder
    # unless you find a way to avoid adding overhead function calls.
    def __lt__(self, other: "_BaseVersion") -> bool:
        if not isinstance(other, _BaseVersion):
            return NotImplemented

        return self._key < other._key

    def __le__(self, other: "_BaseVersion") -> bool:
        if not isinstance(other, _BaseVersion):
            return NotImplemented

        return self._key <= other._key

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, _BaseVersion):
            return NotImplemented

        return self._key == other._key

    def __ge__(self, other: "_BaseVersion") -> bool:
        if not isinstance(other, _BaseVersion):
            return NotImplemented

        return self._key >= other._key

    def __gt__(self, other: "_BaseVersion") -> bool:
        if not isinstance(other, _BaseVersion):
            return NotImplemented

        return self._key > other._key

    def __ne__(self, other: object) -> bool:
        if not isinstance(other, _BaseVersion):
            return NotImplemented

        return self._key != other._key


# Deliberately not anchored to the start and end of the string, to make it
# easier for 3rd party code to reuse
#
# The group names below are the ones looked up by name when a match is
# unpacked (epoch, release, pre_l, pre_n, post_l, post_n1, post_n2, dev_l,
# dev_n, local); the pattern is meant to be compiled with re.VERBOSE.
VERSION_PATTERN = r"""
    v?
    (?:
        (?:(?P<epoch>[0-9]+)!)?                           # epoch
        (?P<release>[0-9]+(?:\.[0-9]+)*)                  # release segment
        (?P<pre>                                          # pre-release
            [-_\.]?
            (?P<pre_l>(a|b|c|rc|alpha|beta|pre|preview))
            [-_\.]?
            (?P<pre_n>[0-9]+)?
        )?
        (?P<post>                                         # post release
            (?:-(?P<post_n1>[0-9]+))
            |
            (?:
                [-_\.]?
                (?P<post_l>post|rev|r)
                [-_\.]?
                (?P<post_n2>[0-9]+)?
            )
        )?
        (?P<dev>                                          # dev release
            [-_\.]?
            (?P<dev_l>dev)
            [-_\.]?
            (?P<dev_n>[0-9]+)?
        )?
    )
    (?:\+(?P<local>[a-z0-9]+(?:[-_\.][a-z0-9]+)*))?       # local version
"""


class Version(_BaseVersion):
    """A parsed PEP 440 version string.

    The constructor matches the input against ``VERSION_PATTERN`` (anchored,
    surrounding whitespace allowed, case-insensitive), stores the parsed
    segments in a ``_Version`` namedtuple and precomputes the comparison key
    used by the ordering and hashing methods inherited from ``_BaseVersion``.

    Raises
    ------
    InvalidVersion
        If the string does not match the version pattern.

    """

    _regex = re.compile(r"^\s*" + VERSION_PATTERN + r"\s*$", re.VERBOSE | re.IGNORECASE)

    def __init__(self, version: str) -> None:

        # Validate the version and parse it into pieces
        match = self._regex.search(version)
        if not match:
            raise InvalidVersion(f"Invalid version: '{version}'")

        # Store the parsed out pieces of the version
        self._version = _Version(
            epoch=int(match.group("epoch")) if match.group("epoch") else 0,
            release=tuple(int(i) for i in match.group("release").split(".")),
            pre=_parse_letter_version(match.group("pre_l"), match.group("pre_n")),
            # The post number comes from either the implicit "-N" form
            # (post_n1) or the explicit "postN" form (post_n2).
            post=_parse_letter_version(
                match.group("post_l"), match.group("post_n1") or match.group("post_n2")
            ),
            dev=_parse_letter_version(match.group("dev_l"), match.group("dev_n")),
            local=_parse_local_version(match.group("local")),
        )

        # Generate a key which will be used for sorting
        self._key = _cmpkey(
            self._version.epoch,
            self._version.release,
            self._version.pre,
            self._version.post,
            self._version.dev,
            self._version.local,
        )

    def __repr__(self) -> str:
        # Show the normalized string form rather than an empty string so the
        # object is identifiable in logs and interactive sessions.
        return f"<Version('{self}')>"

    def __str__(self) -> str:
        """Return the normalized version string."""
        parts = []

        # Epoch
        if self.epoch != 0:
            parts.append(f"{self.epoch}!")

        # Release segment
        parts.append(".".join(str(x) for x in self.release))

        # Pre-release
        if self.pre is not None:
            parts.append("".join(str(x) for x in self.pre))

        # Post-release
        if self.post is not None:
            parts.append(f".post{self.post}")

        # Development release
        if self.dev is not None:
            parts.append(f".dev{self.dev}")

        # Local version segment
        if self.local is not None:
            parts.append(f"+{self.local}")

        return "".join(parts)

    @property
    def epoch(self) -> int:
        """The epoch segment (0 when not given)."""
        _epoch: int = self._version.epoch
        return _epoch

    @property
    def release(self) -> Tuple[int, ...]:
        """The release segment as a tuple of ints."""
        _release: Tuple[int, ...] = self._version.release
        return _release

    @property
    def pre(self) -> Optional[Tuple[str, int]]:
        """The pre-release ``(letter, number)`` pair, or None."""
        _pre: Optional[Tuple[str, int]] = self._version.pre
        return _pre

    @property
    def post(self) -> Optional[int]:
        """The post-release number, or None."""
        return self._version.post[1] if self._version.post else None

    @property
    def dev(self) -> Optional[int]:
        """The development-release number, or None."""
        return self._version.dev[1] if self._version.dev else None

    @property
    def local(self) -> Optional[str]:
        """The local version segment joined with ".", or None."""
        if self._version.local:
            return ".".join(str(x) for x in self._version.local)
        else:
            return None

    @property
    def public(self) -> str:
        """The version string without the local ("+...") segment."""
        return str(self).split("+", 1)[0]

    @property
    def base_version(self) -> str:
        """The epoch and release segments only."""
        parts = []

        # Epoch
        if self.epoch != 0:
            parts.append(f"{self.epoch}!")

        # Release segment
        parts.append(".".join(str(x) for x in self.release))

        return "".join(parts)

    @property
    def is_prerelease(self) -> bool:
        """True when there is a dev or pre segment."""
        return self.dev is not None or self.pre is not None

    @property
    def is_postrelease(self) -> bool:
        """True when there is a post segment."""
        return self.post is not None

    @property
    def is_devrelease(self) -> bool:
        """True when there is a dev segment."""
        return self.dev is not None

    @property
    def major(self) -> int:
        """First release component, or 0 if absent."""
        return self.release[0] if len(self.release) >= 1 else 0

    @property
    def minor(self) -> int:
        """Second release component, or 0 if absent."""
        return self.release[1] if len(self.release) >= 2 else 0

    @property
    def micro(self) -> int:
        """Third release component, or 0 if absent."""
        return self.release[2] if len(self.release) >= 3 else 0


def _parse_letter_version(
    letter: str, number: Union[str, bytes, SupportsInt]
) -> Optional[Tuple[str, int]]:

    if letter:
        # We consider there to be an implicit 0 in a pre-release if there is
        # not a numeral associated with it.
        if number is None:
            number = 0

        # We normalize any letters to their lower case form
        letter = letter.lower()

        # We consider some words to be alternate spellings of other words and
        # in those cases we want to normalize the spellings to our preferred
        # spelling.
        if letter == "alpha":
            letter = "a"
        elif letter == "beta":
            letter = "b"
        elif letter in ["c", "pre", "preview"]:
            letter = "rc"
        elif letter in ["rev", "r"]:
            letter = "post"

        return letter, int(number)
    if not letter and number:
        # We assume if we are given a number, but we are not given a letter
        # then this is using the implicit post release syntax (e.g. 1.0-1)
        letter = "post"

        return letter, int(number)

    return None


_local_version_separators = re.compile(r"[\._-]")


def _parse_local_version(local: str) -> Optional[LocalType]:
    """
    Split a local version label on ".", "_" or "-" and normalize each part:
    all-digit parts become ints, everything else is lower-cased.

    For example, abc.1.twelve becomes ("abc", 1, "twelve"). None is passed
    through unchanged.
    """
    if local is None:
        return None

    segments = _local_version_separators.split(local)
    return tuple(
        int(segment) if segment.isdigit() else segment.lower()
        for segment in segments
    )


def _cmpkey(
    epoch: int,
    release: Tuple[int, ...],
    pre: Optional[Tuple[str, int]],
    post: Optional[Tuple[str, int]],
    dev: Optional[Tuple[str, int]],
    local: Optional[Tuple[SubLocalType]],
) -> CmpKey:
    """Assemble the tuple that versions are compared and hashed by."""

    # Release segments compare with trailing zeros ignored, so peel them off
    # the end before freezing the result into a tuple.
    trimmed = list(release)
    while trimmed and trimmed[-1] == 0:
        trimmed.pop()
    _release = tuple(trimmed)

    # An explicit pre-release is used as-is. Otherwise the slot is filled
    # with a sentinel: a bare dev release (no pre, no post) must sort before
    # any pre-release (1.0.dev0 < 1.0a0), while every other version without
    # a pre-release sorts after those that have one.
    if pre is not None:
        _pre: PrePostDevType = pre
    elif post is None and dev is not None:
        _pre = NegativeInfinity
    else:
        _pre = Infinity

    # No post segment sorts before any post segment.
    _post: PrePostDevType = NegativeInfinity if post is None else post

    # No dev segment sorts after any dev segment.
    _dev: PrePostDevType = Infinity if dev is None else dev

    if local is None:
        # No local segment sorts before any local segment.
        _local: LocalType = NegativeInfinity
    else:
        # PEP 440 ordering for local segments:
        # - alphanumeric parts sort before numeric parts
        # - alphanumeric parts sort lexicographically
        # - numeric parts sort numerically
        # - a shorter label sorts before a longer one sharing its prefix
        _local = tuple(
            (NegativeInfinity, part) if not isinstance(part, int) else (part, "")
            for part in local
        )

    return epoch, _release, _pre, _post, _dev, _local


================================================
FILE: seaborn/matrix.py
================================================
"""Functions to visualize matrices of data."""
import warnings

import matplotlib as mpl
from matplotlib.collections import LineCollection
import matplotlib.pyplot as plt
from matplotlib import gridspec
import numpy as np
import pandas as pd
try:
    from scipy.cluster import hierarchy
    _no_scipy = False
except ImportError:
    _no_scipy = True

from . import cm
from .axisgrid import Grid
from ._compat import get_colormap
from .utils import (
    despine,
    axis_ticklabels_overlap,
    relative_luminance,
    to_utf8,
    _draw_figure,
)


__all__ = ["heatmap", "clustermap"]


def _index_to_label(index):
    """Convert a pandas index or multiindex to an axis label."""
    if isinstance(index, pd.MultiIndex):
        return "-".join(map(to_utf8, index.names))
    else:
        return index.name


def _index_to_ticklabels(index):
    """Convert a pandas index or multiindex into ticklabels."""
    if isinstance(index, pd.MultiIndex):
        return ["-".join(map(to_utf8, i)) for i in index.values]
    else:
        return index.values


def _convert_colors(colors):
    """Convert a flat list of colors, or a list of such lists, to RGB."""
    as_rgb = mpl.colors.to_rgb

    try:
        # Probe the first entry: if it is itself a color, the input is flat.
        as_rgb(colors[0])
        flat = [as_rgb(color) for color in colors]
    except ValueError:
        # Otherwise each entry is treated as its own list of colors.
        return [[as_rgb(color) for color in group] for group in colors]
    return flat


def _matrix_mask(data, mask):
    """Ensure that data and mask are compatible and add missing values.

    Values will be plotted for cells where ``mask`` is ``False``.

    ``data`` is expected to be a DataFrame; ``mask`` can be an array or
    a DataFrame.

    """
    if mask is None:
        mask = np.zeros(data.shape, bool)

    if isinstance(mask, pd.DataFrame):
        # For DataFrame masks, ensure that semantic labels match data
        if not mask.index.equals(data.index) \
           and mask.columns.equals(data.columns):
            err = "Mask must have the same index and columns as data."
            raise ValueError(err)
    elif hasattr(mask, "__array__"):
        mask = np.asarray(mask)
        # For array masks, ensure that shape matches data then convert
        if mask.shape != data.shape:
            raise ValueError("Mask must have the same shape as data.")

        mask = pd.DataFrame(mask,
                            index=data.index,
                            columns=data.columns,
                            dtype=bool)

    # Add any cells with missing data to the mask
    # This works around an issue where `plt.pcolormesh` doesn't represent
    # missing data properly
    mask = mask | pd.isnull(data)

    return mask


class _HeatMapper:
    """Draw a heatmap plot of a matrix with nice labels and colormaps."""

    def __init__(self, data, vmin, vmax, cmap, center, robust, annot, fmt,
                 annot_kws, cbar, cbar_kws,
                 xticklabels=True, yticklabels=True, mask=None):
        """Initialize the plotting object.

        Normalizes the inputs and stores everything needed for drawing as
        attributes: ``data`` (DataFrame), ``plot_data`` (masked array),
        ``xticks``/``xticklabels``, ``yticks``/``yticklabels``,
        ``xlabel``/``ylabel``, ``annot``/``annot_data``, ``fmt``,
        ``annot_kws``, ``cbar`` and ``cbar_kws``. The colormap attributes are
        set by ``_determine_cmap_params``.

        Parameters
        ----------
        data : DataFrame or array-like
            Non-DataFrame input is converted with ``np.asarray`` and wrapped
            in a DataFrame.
        vmin, vmax, cmap, center, robust
            Forwarded to ``_determine_cmap_params``.
        annot : None, bool, or array-like
            ``None``/``False`` disables annotation; ``True`` (any bool other
            than ``False``) annotates with the plotted values; anything else
            is converted with ``np.asarray`` and must match the data shape.
        fmt
            Stored as ``self.fmt``.
        annot_kws, cbar_kws : dict or None
            Copied (or replaced by an empty dict when ``None``).
        cbar
            Stored as ``self.cbar``.
        xticklabels, yticklabels : int, bool, "auto", or sequence
            An int selects every n-th label taken from the columns/index,
            ``"auto"`` defers tick placement, and any other value with a
            length is used as the labels themselves.
        mask : None, DataFrame, or array-like
            Validated by ``_matrix_mask``.

        Raises
        ------
        ValueError
            If ``annot`` is array-like with a shape different from the data.

        """
        # We always want to have a DataFrame with semantic information
        # and an ndarray to pass to matplotlib
        if isinstance(data, pd.DataFrame):
            plot_data = data.values
        else:
            plot_data = np.asarray(data)
            data = pd.DataFrame(plot_data)

        # Validate the mask and convert to DataFrame
        mask = _matrix_mask(data, mask)

        # Cells where the mask is true are hidden from matplotlib
        plot_data = np.ma.masked_where(np.asarray(mask), plot_data)

        # Get good names for the rows and columns
        # NOTE: bool is a subclass of int, so True/False are handled by the
        # isinstance branch (tick step True == 1, False == 0); the explicit
        # `is True` / `is False` branches are not reached for real bools.
        # `_skip_ticks` treats a step of 0 as "no ticks".
        xtickevery = 1
        if isinstance(xticklabels, int):
            xtickevery = xticklabels
            xticklabels = _index_to_ticklabels(data.columns)
        elif xticklabels is True:
            xticklabels = _index_to_ticklabels(data.columns)
        elif xticklabels is False:
            xticklabels = []

        # Same handling for the row labels
        ytickevery = 1
        if isinstance(yticklabels, int):
            ytickevery = yticklabels
            yticklabels = _index_to_ticklabels(data.index)
        elif yticklabels is True:
            yticklabels = _index_to_ticklabels(data.index)
        elif yticklabels is False:
            yticklabels = []

        # Empty labels -> no ticks; "auto" -> ticks are left as the string
        # "auto" with the full set of labels; otherwise subsample by step
        if not len(xticklabels):
            self.xticks = []
            self.xticklabels = []
        elif isinstance(xticklabels, str) and xticklabels == "auto":
            self.xticks = "auto"
            self.xticklabels = _index_to_ticklabels(data.columns)
        else:
            self.xticks, self.xticklabels = self._skip_ticks(xticklabels,
                                                             xtickevery)

        if not len(yticklabels):
            self.yticks = []
            self.yticklabels = []
        elif isinstance(yticklabels, str) and yticklabels == "auto":
            self.yticks = "auto"
            self.yticklabels = _index_to_ticklabels(data.index)
        else:
            self.yticks, self.yticklabels = self._skip_ticks(yticklabels,
                                                             ytickevery)

        # Get good names for the axis labels
        xlabel = _index_to_label(data.columns)
        ylabel = _index_to_label(data.index)
        self.xlabel = xlabel if xlabel is not None else ""
        self.ylabel = ylabel if ylabel is not None else ""

        # Determine good default values for the colormapping
        self._determine_cmap_params(plot_data, vmin, vmax,
                                    cmap, center, robust)

        # Sort out the annotations
        if annot is None or annot is False:
            annot = False
            annot_data = None
        else:
            if isinstance(annot, bool):
                # annot=True: annotate each cell with its own value
                annot_data = plot_data
            else:
                # Explicit annotation values, one per cell
                annot_data = np.asarray(annot)
                if annot_data.shape != plot_data.shape:
                    err = "`data` and `annot` must have same shape."
                    raise ValueError(err)
            annot = True

        # Save other attributes to the object
        self.data = data
        self.plot_data = plot_data

        self.annot = annot
        self.annot_data = annot_data

        self.fmt = fmt
        # Copy the keyword dicts so later changes to the stored dicts do not
        # affect the caller's objects
        self.annot_kws = {} if annot_kws is None else annot_kws.copy()
        self.cbar = cbar
        self.cbar_kws = {} if cbar_kws is None else cbar_kws.copy()

    def _determine_cmap_params(self, plot_data, vmin, vmax,
                               cmap, center, robust):
        """Use some heuristics to set good defaults for colorbar and range."""

        # plot_data is a np.ma.array instance
        calc_data = plot_data.astype(float).filled(np.nan)
        if vmin is None:
            if robust:
                vmin = np.nanpercentile(calc_data, 2)
            else:
                vmin = np.nanmin(calc_data)
        if vmax is None:
            if robust:
                vmax = np.nanpercentile(calc_data, 98)
            else:
                vmax = np.nanmax(calc_data)
        self.vmin, self.vmax = vmin, vmax

        # Choose default colormaps if not provided
        if cmap is None:
            if center is None:
                self.cmap = cm.rocket
            else:
                self.cmap = cm.icefire
        elif isinstance(cmap, str):
            self.cmap = get_colormap(cmap)
        elif isinstance(cmap, list):
            self.cmap = mpl.colors.ListedColormap(cmap)
        else:
            self.cmap = cmap

        # Recenter a divergent colormap
        if center is not None:

            # Copy bad values
            # in mpl<3.2 only masked values are honored with "bad" color spec
            # (see https://github.com/matplotlib/matplotlib/pull/14257)
            bad = self.cmap(np.ma.masked_invalid([np.nan]))[0]

            # under/over values are set for sure when cmap extremes
            # do not map to the same color as +-inf
            under = self.cmap(-np.inf)
            over = self.cmap(np.inf)
            under_set = under != self.cmap(0)
            over_set = over != self.cmap(self.cmap.N - 1)

            vrange = max(vmax - center, center - vmin)
            normlize = mpl.colors.Normalize(center - vrange, center + vrange)
            cmin, cmax = normlize([vmin, vmax])
            cc = np.linspace(cmin, cmax, 256)
            self.cmap = mpl.colors.ListedColormap(self.cmap(cc))
            self.cmap.set_bad(bad)
            if under_set:
                self.cmap.set_under(under)
            if over_set:
                self.cmap.set_over(over)

    def _annotate_heatmap(self, ax, mesh):
        """Write the formatted annotation value at the center of each cell.

        Masked cells are skipped. The default text color is chosen from the
        luminance of the cell's face color and can be overridden through
        ``self.annot_kws``.
        """
        # Make sure the face colors reflect the current data and colormap
        mesh.update_scalarmappable()

        n_rows, n_cols = self.annot_data.shape
        col_centers = np.arange(n_cols) + .5
        row_centers = np.arange(n_rows) + .5
        xgrid, ygrid = np.meshgrid(col_centers, row_centers)

        cells = zip(
            xgrid.flat,
            ygrid.flat,
            mesh.get_array().flat,
            mesh.get_facecolors(),
            self.annot_data.flat,
        )
        for x, y, cell, facecolor, value in cells:
            if cell is np.ma.masked:
                continue

            # Dark text on light cells, white text on dark cells
            if relative_luminance(facecolor) > .408:
                text_color = ".15"
            else:
                text_color = "w"

            label = ("{:" + self.fmt + "}").format(value)
            text_kwargs = {
                "color": text_color, "ha": "center", "va": "center",
                **self.annot_kws,
            }
            ax.text(x, y, label, **text_kwargs)

    def _skip_ticks(self, labels, tickevery):
        """Return ticks and labels at evenly spaced intervals."""
        n = len(labels)
        if tickevery == 0:
            ticks, labels = [], []
        elif tickevery == 1:
            ticks, labels = np.arange(n) + .5, labels
        else:
            start, end, step = 0, n, tickevery
            ticks = np.arange(start, end, step) + .5
            labels = labels[start:end:step]
        return ticks, labels

    def _auto_ticks(self, ax, labels, axis):
        """Choose a label step so the tick labels fit along one axis.

        ``axis`` is 0 for the x axis and 1 for the y axis. Returns the
        ``(ticks, labels)`` pair produced by ``_skip_ticks``, or two empty
        lists when not even one label fits.
        """
        # Extent of the Axes after undoing the figure's dpi scaling
        undo_dpi = ax.figure.dpi_scale_trans.inverted()
        extent = ax.get_window_extent().transformed(undo_dpi)
        available = (extent.width, extent.height)[axis]
        axis_obj = (ax.xaxis, ax.yaxis)[axis]

        # Create a single tick so its label can report the font size in use
        (probe,) = axis_obj.set_ticks([0])
        label_size = probe.label1.get_size()

        # Font size is divided by 72 to compare it with the extent above
        max_ticks = int(available // (label_size / 72))
        if max_ticks < 1:
            return [], []

        step = len(labels) // max_ticks + 1
        if step == 0:
            step = 1
        return self._skip_ticks(labels, step)

    def plot(self, ax, cax, kws):
        """Render the heatmap, optional colorbar, ticks, and annotations.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes that receives the mesh.
        cax : matplotlib Axes or None
            Passed to ``Figure.colorbar`` as the Axes for the colorbar.
        kws : dict
            Keyword arguments for ``pcolormesh``. Updated in place with the
            stored ``vmin``/``vmax`` as defaults when no ``norm`` is given.
        """
        # Heatmaps are drawn without any Axes spines
        despine(ax=ax, left=True, bottom=True)

        # setting vmin/vmax in addition to norm is deprecated
        # so only fall back to the stored limits when norm is absent
        if kws.get("norm") is None:
            for key, limit in (("vmin", self.vmin), ("vmax", self.vmax)):
                kws.setdefault(key, limit)

        mesh = ax.pcolormesh(self.plot_data, cmap=self.cmap, **kws)

        # One unit per cell, with the y axis flipped so that the first
        # row is shown at the top (matrix orientation)
        shape = self.data.shape
        ax.set(xlim=(0, shape[1]), ylim=(0, shape[0]))
        ax.invert_yaxis()

        if self.cbar:
            colorbar = ax.figure.colorbar(mesh, cax, ax, **self.cbar_kws)
            colorbar.outline.set_linewidth(0)
            # If rasterized is passed to pcolormesh, also rasterize the
            # colorbar to avoid white lines on the PDF rendering
            if kws.get('rasterized', False):
                colorbar.solids.set_rasterized(True)

        def resolve(ticks, labels, axis_index):
            # "auto" defers to the overlap-minimizing heuristic
            if isinstance(ticks, str) and ticks == "auto":
                return self._auto_ticks(ax, labels, axis_index)
            return ticks, labels

        xticks, xticklabels = resolve(self.xticks, self.xticklabels, 0)
        yticks, yticklabels = resolve(self.yticks, self.yticklabels, 1)

        ax.set(xticks=xticks, yticks=yticks)
        x_texts = ax.set_xticklabels(xticklabels)
        y_texts = ax.set_yticklabels(yticklabels, rotation="vertical")
        plt.setp(y_texts, va="center")  # GH2484

        # The figure has to be drawn before label overlap can be measured
        _draw_figure(ax.figure)

        if axis_ticklabels_overlap(x_texts):
            plt.setp(x_texts, rotation="vertical")
        if axis_ticklabels_overlap(y_texts):
            plt.setp(y_texts, rotation="horizontal")

        ax.set(xlabel=self.xlabel, ylabel=self.ylabel)

        # Cell annotations go on last
        if self.annot:
            self._annotate_heatmap(ax, mesh)


def heatmap(
    data, *,
    vmin=None, vmax=None, cmap=None, center=None, robust=False,
    annot=None, fmt=".2g", annot_kws=None,
    linewidths=0, linecolor="white",
    cbar=True, cbar_kws=None, cbar_ax=None,
    square=False, xticklabels="auto", yticklabels="auto",
    mask=None, ax=None,
    **kwargs
):
    """Plot rectangular data as a color-encoded matrix.

    This is an Axes-level function: the heatmap is drawn into the Axes given
    by ``ax``, or into the currently-active Axes when ``ax`` is not provided.
    Part of that Axes space is taken to draw a colorbar, unless ``cbar`` is
    False or a separate Axes is provided to ``cbar_ax``.

    Parameters
    ----------
    data : rectangular dataset
        2D dataset that can be coerced into an ndarray. If a Pandas DataFrame
        is provided, the index/column information will be used to label the
        columns and rows.
    vmin, vmax : floats, optional
        Values to anchor the colormap, otherwise they are inferred from the
        data and other keyword arguments.
    cmap : matplotlib colormap name or object, or list of colors, optional
        The mapping from data values to color space. If not provided, the
        default will depend on whether ``center`` is set.
    center : float, optional
        The value at which to center the colormap when plotting divergent
        data. Using this parameter will change the default ``cmap`` if none
        is specified.
    robust : bool, optional
        If True and ``vmin`` or ``vmax`` are absent, the colormap range is
        computed with robust quantiles instead of the extreme values.
    annot : bool or rectangular dataset, optional
        If True, write the data value in each cell. If an array-like with the
        same shape as ``data``, then use this to annotate the heatmap instead
        of the data. Note that DataFrames will match on position, not index.
    fmt : str, optional
        String formatting code to use when adding annotations.
    annot_kws : dict of key, value mappings, optional
        Keyword arguments for :meth:`matplotlib.axes.Axes.text` when ``annot``
        is True.
    linewidths : float, optional
        Width of the lines that will divide each cell.
    linecolor : color, optional
        Color of the lines that will divide each cell.
    cbar : bool, optional
        Whether to draw a colorbar.
    cbar_kws : dict of key, value mappings, optional
        Keyword arguments for :meth:`matplotlib.figure.Figure.colorbar`.
    cbar_ax : matplotlib Axes, optional
        Axes in which to draw the colorbar, otherwise take space from the
        main Axes.
    square : bool, optional
        If True, set the Axes aspect to "equal" so each cell will be
        square-shaped.
    xticklabels, yticklabels : "auto", bool, list-like, or int, optional
        If True, plot the column names of the dataframe. If False, don't plot
        the column names. If list-like, plot these alternate labels as the
        xticklabels. If an integer, use the column names but plot only every
        n label. If "auto", try to densely plot non-overlapping labels.
    mask : bool array or DataFrame, optional
        If passed, data will not be shown in cells where ``mask`` is True.
        Cells with missing values are automatically masked.
    ax : matplotlib Axes, optional
        Axes in which to draw the plot, otherwise use the currently-active
        Axes.
    kwargs : other keyword arguments
        All other keyword arguments are passed to
        :meth:`matplotlib.axes.Axes.pcolormesh`.

    Returns
    -------
    ax : matplotlib Axes
        Axes object with the heatmap.

    See Also
    --------
    clustermap : Plot a matrix using hierarchical clustering to arrange the
                 rows and columns.

    Examples
    --------

    .. include:: ../docstrings/heatmap.rst

    """
    # Validate and preprocess the inputs before touching any Axes
    mapper = _HeatMapper(
        data, vmin, vmax, cmap, center, robust, annot, fmt, annot_kws,
        cbar, cbar_kws, xticklabels, yticklabels, mask,
    )

    # The cell borders are forwarded to pcolormesh under its own names
    kwargs.update(linewidths=linewidths, edgecolor=linecolor)

    # Fall back to the currently-active Axes when none was given
    target = plt.gca() if ax is None else ax
    if square:
        target.set_aspect("equal")

    mapper.plot(target, cbar_ax, kwargs)
    return target


class _DendrogramPlotter:
    """Object for drawing tree of similarities between data rows/columns"""

    def __init__(self, data, linkage, metric, method, axis, label, rotate):
        """Plot a dendrogram of the relationships between the columns of data

        Parameters
        ----------
        data : pandas.DataFrame
            Rectangular data
        """
        self.axis = axis
        if self.axis == 1:
            data = data.T

        if isinstance(data, pd.DataFrame):
            array = data.values
        else:
            array = np.asarray(data)
            data = pd.DataFrame(array)

        self.array = array
        self.data = data

        self.shape = self.data.shape
        self.metric = metric
        self.method = method
        self.axis = axis
        self.label = label
        self.rotate = rotate

        if linkage is None:
            self.linkage = self.calculated_linkage
        else:
            self.linkage = linkage
        self.dendrogram = self.calculate_dendrogram()

        # Dendrogram ends are always at multiples of 5, who knows why
        ticks = 10 * np.arange(self.data.shape[0]) + 5

        if self.label:
            ticklabels = _index_to_ticklabels(self.data.index)
            ticklabels = [ticklabels[i] for i in self.reordered_ind]
            if self.rotate:
                self.xticks = []
                self.yticks = ticks
                self.xticklabels = []

                self.yticklabels = ticklabels
                self.ylabel = _index_to_label(self.data.index)
                self.xlabel = ''
            else:
                self.xticks = ticks
                self.yticks = []
                self.xticklabels = ticklabels
                self.yticklabels = []
                self.ylabel = ''
                self.xlabel = _index_to_label(self.data.index)
        else:
            self.xticks, self.yticks = [], []
            self.yticklabels, self.xticklabels = [], []
            self.xlabel, self.ylabel = '', ''

        self.dependent_coord = self.dendrogram['dcoord']
        self.independent_coord = self.dendrogram['icoord']

    def _calculate_linkage_scipy(self):
        linkage = hierarchy.linkage(self.array, method=self.method,
                                    metric=self.metric)
        return linkage

    def _calculate_linkage_fastcluster(self):
        import fastcluster
        # Fastcluster has a memory-saving vectorized version, but only
        # with certain linkage methods, and mostly with euclidean metric
        # vector_methods = ('single', 'centroid', 'median', 'ward')
        euclidean_methods = ('centroid', 'median', 'ward')
        euclidean = self.metric == 'euclidean' and self.method in \
            euclidean_methods
        if euclidean or self.method == 'single':
            return fastcluster.linkage_vector(self.array,
                                              method=self.method,
                                              metric=self.metric)
        else:
            linkage = fastcluster.linkage(self.array, method=self.method,
                                          metric=self.metric)
            return linkage

    @property
    def calculated_linkage(self):

        try:
            return self._calculate_linkage_fastcluster()
        except ImportError:
            if np.prod(self.shape) >= 10000:
                msg = ("Clustering large matrix with scipy. Installing "
                       "`fastcluster` may give better performance.")
                warnings.warn(msg)

        return self._calculate_linkage_scipy()

    def calculate_dendrogram(self):
        """Calculates a dendrogram based on the linkage matrix

        Made a separate function, not a property because don't want to
        recalculate the dendrogram every time it is accessed.

        Returns
        -------
        dendrogram : dict
            Dendrogram dictionary as returned by scipy.cluster.hierarchy
            .dendrogram. The important key-value pairing is
            "reordered_ind" which indicates the re-ordering of the matrix
        """
        return hierarchy.dendrogram(self.linkage, no_plot=True,
                                    color_threshold=-np.inf)

    @property
    def reordered_ind(self):
        """Indices of the matrix, reordered by the dendrogram"""
        return self.dendrogram['leaves']

    def plot(self, ax, tree_kws):
        """Plots a dendrogram of the similarities between data on the axes

        Parameters
        ----------
        ax : matplotlib.axes.Axes
            Axes object upon which the dendrogram is plotted

        """
        tree_kws = {} if tree_kws is None else tree_kws.copy()
        tree_kws.setdefault("linewidths", .5)
        tree_kws.setdefault("colors", tree_kws.pop("color", (.2, .2, .2)))

        if self.rotate and self.axis == 0:
            coords = zip(self.dependent_coord, self.independent_coord)
        else:
            coords = zip(self.independent_coord, self.dependent_coord)
        lines = LineCollection([list(zip(x, y)) for x, y in coords],
                               **tree_kws)

        ax.add_collection(lines)
        number_of_leaves = len(self.reordered_ind)
        max_dependent_coord = max(map(max, self.dependent_coord))

        if self.rotate:
            ax.yaxis.set_ticks_position('right')

            # Constants 10 and 1.05 come from
            # `scipy.cluster.hierarchy._plot_dendrogram`
            ax.set_ylim(0, number_of_leaves * 10)
            ax.set_xlim(0, max_dependent_coord * 1.05)

            ax.invert_xaxis()
            ax.invert_yaxis()
        else:
            # Constants 10 and 1.05 come from
            # `scipy.cluster.hierarchy._plot_dendrogram`
            ax.set_xlim(0, number_of_leaves * 10)
            ax.set_ylim(0, max_dependent_coord * 1.05)

        despine(ax=ax, bottom=True, left=True)

        ax.set(xticks=self.xticks, yticks=self.yticks,
               xlabel=self.xlabel, ylabel=self.ylabel)
        xtl = ax.set_xticklabels(self.xticklabels)
        ytl = ax.set_yticklabels(self.yticklabels, rotation='vertical')

        # Force a draw of the plot to avoid matplotlib window error
        _draw_figure(ax.figure)

        if len(ytl) > 0 and axis_ticklabels_overlap(ytl):
            plt.setp(ytl, rotation="horizontal")
        if len(xtl) > 0 and axis_ticklabels_overlap(xtl):
            plt.setp(xtl, rotation="vertical")
        return self


def dendrogram(
    data, *,
    linkage=None, axis=1, label=True, metric='euclidean',
    method='average', rotate=False, tree_kws=None, ax=None
):
    """Draw a tree diagram of relationships within a matrix

    Parameters
    ----------
    data : pandas.DataFrame
        Rectangular data
    linkage : numpy.array, optional
        Linkage matrix; calculated from ``data`` when not given
    axis : int, optional
        Which axis to use to calculate linkage. 0 is rows, 1 is columns.
    label : bool, optional
        If True, label the dendrogram at leaves with column or row names
    metric : str, optional
        Distance metric. Anything valid for scipy.spatial.distance.pdist
    method : str, optional
        Linkage method to use. Anything valid for
        scipy.cluster.hierarchy.linkage
    rotate : bool, optional
        When plotting the matrix, whether to rotate it 90 degrees
        counter-clockwise, so the leaves face right
    tree_kws : dict, optional
        Keyword arguments for the ``matplotlib.collections.LineCollection``
        that is used for plotting the lines of the dendrogram tree.
    ax : matplotlib axis, optional
        Axis to plot on, otherwise uses current axis

    Returns
    -------
    dendrogramplotter : _DendrogramPlotter
        A Dendrogram plotter object.

    Raises
    ------
    RuntimeError
        If scipy is not installed.

    Notes
    -----
    Access the reordered dendrogram indices with
    dendrogramplotter.reordered_ind

    """
    if _no_scipy:
        raise RuntimeError("dendrogram requires scipy to be installed")

    # The clustering happens in the constructor, before any Axes is used
    settings = dict(
        linkage=linkage, axis=axis, metric=metric, method=method,
        label=label, rotate=rotate,
    )
    tree = _DendrogramPlotter(data, **settings)

    target = plt.gca() if ax is None else ax
    return tree.plot(ax=target, tree_kws=tree_kws)


class ClusterGrid(Grid):

    def __init__(self, data, pivot_kws=None, z_score=None, standard_scale=None,
                 figsize=None, row_colors=None, col_colors=None, mask=None,
                 dendrogram_ratio=None, colors_ratio=None, cbar_pos=None):
        """Grid object for organizing clustered heatmap input on to axes

        Parameters
        ----------
        data : pandas.DataFrame or 2D array-like
            Input data; wrapped in a DataFrame when it is not one already.
        pivot_kws : dict, optional
            Keyword arguments for ``DataFrame.pivot``, applied to ``data``
            to obtain the rectangular matrix.
        z_score, standard_scale : int, optional
            Axis argument for :meth:`z_score` / :meth:`standard_scale`.
            Only one of the two may be given.
        figsize : tuple, optional
            Passed to ``plt.figure``.
        row_colors, col_colors : colors, optional
            Side colors for the rows / columns. A pandas Series or DataFrame
            is aligned with the index / columns of ``data``.
        mask : bool array or DataFrame, optional
            Passed to ``_matrix_mask`` together with the 2D data.
        dendrogram_ratio, colors_ratio : float or pair of floats
            Proportion of the figure used by the dendrograms and by each
            side-color strip. A pair is unpacked as ``(row, col)``.
        cbar_pos : tuple or None
            Position handed to ``Axes.set_position`` for the colorbar Axes.
            When None, no colorbar Axes is created.

        Raises
        ------
        RuntimeError
            If scipy is not available.
        """
        if _no_scipy:
            raise RuntimeError("ClusterGrid requires scipy to be available")

        if isinstance(data, pd.DataFrame):
            self.data = data
        else:
            self.data = pd.DataFrame(data)

        self.data2d = self.format_data(self.data, pivot_kws, z_score,
                                       standard_scale)

        self.mask = _matrix_mask(self.data2d, mask)

        self._figure = plt.figure(figsize=figsize)

        # The raw ``data`` is passed (not ``self.data``) so that unindexed
        # input can be detected when pandas colors are given.
        # NOTE(review): with ``pivot_kws`` the index of ``data`` can differ
        # from that of ``self.data2d``; confirm the alignment in that case.
        self.row_colors, self.row_color_labels = \
            self._preprocess_colors(data, row_colors, axis=0)
        self.col_colors, self.col_color_labels = \
            self._preprocess_colors(data, col_colors, axis=1)

        # A scalar ratio cannot be unpacked and applies to both dimensions
        try:
            row_dendrogram_ratio, col_dendrogram_ratio = dendrogram_ratio
        except TypeError:
            row_dendrogram_ratio = col_dendrogram_ratio = dendrogram_ratio

        try:
            row_colors_ratio, col_colors_ratio = colors_ratio
        except TypeError:
            row_colors_ratio = col_colors_ratio = colors_ratio

        width_ratios = self.dim_ratios(self.row_colors,
                                       row_dendrogram_ratio,
                                       row_colors_ratio)
        height_ratios = self.dim_ratios(self.col_colors,
                                        col_dendrogram_ratio,
                                        col_colors_ratio)

        # Side colors add one extra grid row / column
        nrows = 2 if self.col_colors is None else 3
        ncols = 2 if self.row_colors is None else 3

        self.gs = gridspec.GridSpec(nrows, ncols,
                                    width_ratios=width_ratios,
                                    height_ratios=height_ratios)

        self.ax_row_dendrogram = self._figure.add_subplot(self.gs[-1, 0])
        self.ax_col_dendrogram = self._figure.add_subplot(self.gs[0, -1])
        self.ax_row_dendrogram.set_axis_off()
        self.ax_col_dendrogram.set_axis_off()

        self.ax_row_colors = None
        self.ax_col_colors = None

        if self.row_colors is not None:
            self.ax_row_colors = self._figure.add_subplot(
                self.gs[-1, 1])
        if self.col_colors is not None:
            self.ax_col_colors = self._figure.add_subplot(
                self.gs[1, -1])

        self.ax_heatmap = self._figure.add_subplot(self.gs[-1, -1])
        if cbar_pos is None:
            self.ax_cbar = self.cax = None
        else:
            # Initialize the colorbar axes in the gridspec so that tight_layout
            # works. We will move it where it belongs later. This is a hack.
            self.ax_cbar = self._figure.add_subplot(self.gs[0, 0])
            self.cax = self.ax_cbar  # Backwards compatibility
        self.cbar_pos = cbar_pos

        self.dendrogram_row = None
        self.dendrogram_col = None

    def _preprocess_colors(self, data, colors, axis):
        """Preprocess {row/col}_colors to extract labels and convert colors.

        Returns ``(colors, labels)``. ``labels`` is None unless ``colors``
        is a pandas object, in which case it holds the column names of a
        DataFrame or the name of a Series (``""`` if the Series is unnamed).

        Raises
        ------
        TypeError
            If ``colors`` is a pandas object but ``data`` lacks the
            ``index`` (``axis == 0``) or ``columns`` (``axis == 1``)
            attribute needed to align it.
        """
        labels = None

        if colors is not None:
            if isinstance(colors, (pd.DataFrame, pd.Series)):

                # If data is unindexed, raise
                if (not hasattr(data, "index") and axis == 0) or (
                    not hasattr(data, "columns") and axis == 1
                ):
                    axis_name = "col" if axis else "row"
                    msg = (f"{axis_name}_colors indices can't be matched with data "
                           f"indices. Provide {axis_name}_colors as a non-indexed "
                           "datatype, e.g. by using `.to_numpy()``")
                    raise TypeError(msg)

                # Ensure colors match data indices
                if axis == 0:
                    colors = colors.reindex(data.index)
                else:
                    colors = colors.reindex(data.columns)

                # Replace na's with white color
                # TODO We should set these to transparent instead
                colors = colors.astype(object).fillna('white')

                # Extract color values and labels from frame/series
                if isinstance(colors, pd.DataFrame):
                    labels = list(colors.columns)
                    colors = colors.T.values
                else:
                    if colors.name is None:
                        labels = [""]
                    else:
                        labels = [colors.name]
                    colors = colors.values

            colors = _convert_colors(colors)

        return colors, labels

    def format_data(self, data, pivot_kws, z_score=None,
                    standard_scale=None):
        """Extract variables from data or use directly.

        Pivots ``data`` when ``pivot_kws`` is given, then applies at most
        one of z-scoring and standard-scaling.

        Raises
        ------
        ValueError
            If both ``z_score`` and ``standard_scale`` are given.
        """

        # Either the data is already in 2d matrix format, or need to do a pivot
        if pivot_kws is not None:
            data2d = data.pivot(**pivot_kws)
        else:
            data2d = data

        if z_score is not None and standard_scale is not None:
            raise ValueError(
                'Cannot perform both z-scoring and standard-scaling on data')

        if z_score is not None:
            data2d = self.z_score(data2d, z_score)
        if standard_scale is not None:
            data2d = self.standard_scale(data2d, standard_scale)
        return data2d

    @staticmethod
    def z_score(data2d, axis=1):
        """Standardize the mean and variance of the data axis

        Parameters
        ----------
        data2d : pandas.DataFrame
            Data to normalize
        axis : int
            Which axis to normalize across. With 1 the data are used as
            given, so each column is standardized; with any other value the
            data are transposed first, so each row is standardized.

        Returns
        -------
        normalized : pandas.DataFrame
            Normalized data with a mean of 0 and variance of 1 across the
            specified axis.
        """
        if axis == 1:
            z_scored = data2d
        else:
            z_scored = data2d.T

        z_scored = (z_scored - z_scored.mean()) / z_scored.std()

        if axis == 1:
            return z_scored
        else:
            return z_scored.T

    @staticmethod
    def standard_scale(data2d, axis=1):
        """Subtract the minimum and divide by the range (max - min)

        Parameters
        ----------
        data2d : pandas.DataFrame
            Data to normalize
        axis : int
            Which axis to normalize across. With 1 the data are used as
            given, so each column is rescaled; with any other value the
            data are transposed first, so each row is rescaled.

        Returns
        -------
        standardized : pandas.DataFrame
            Data rescaled so that the values span 0 to 1 across the
            specified axis.

        """
        # Normalize these values to range from 0 to 1
        if axis == 1:
            standardized = data2d
        else:
            standardized = data2d.T

        minimum = standardized.min()
        standardized = (standardized - minimum) / (
            standardized.max() - minimum)

        if axis == 1:
            return standardized
        else:
            return standardized.T

    def dim_ratios(self, colors, dendrogram_ratio, colors_ratio):
        """Get the proportions of the figure taken up by each axes.

        Returns a list with the dendrogram ratio, the side-colors ratio
        (only when ``colors`` is given), and the remainder for the heatmap.
        """
        ratios = [dendrogram_ratio]

        if colors is not None:
            # Colors are encoded as rgb, so there is an extra dimension
            if np.ndim(colors) > 2:
                n_colors = len(colors)
            else:
                n_colors = 1

            ratios += [n_colors * colors_ratio]

        # Add the ratio for the heatmap itself
        ratios.append(1 - sum(ratios))

        return ratios

    @staticmethod
    def color_list_to_matrix_and_cmap(colors, ind, axis=0):
        """Turns a list of colors into a numpy matrix and matplotlib colormap

        These arguments can now be plotted using heatmap(matrix, cmap)
        and the provided colors will be plotted.

        Parameters
        ----------
        colors : list of matplotlib colors
            Colors to label the rows or columns of a dataframe. May also be
            a list of such lists, all of the same length.
        ind : list of ints
            Ordering of the rows or columns, to reorder the original colors
            by the clustered dendrogram order
        axis : int
            Which axis this is labeling

        Returns
        -------
        matrix : numpy.array
            A numpy array of integer values, where each indexes into the cmap
        cmap : matplotlib.colors.ListedColormap

        Raises
        ------
        ValueError
            If multiple color vectors are given and their lengths differ.

        """
        # If the first element is itself a valid color, the input is a
        # single vector of colors; otherwise it is a list of vectors
        try:
            mpl.colors.to_rgb(colors[0])
        except ValueError:
            # We have a 2D color structure
            m, n = len(colors), len(colors[0])
            if not all(len(c) == n for c in colors[1:]):
                raise ValueError("Multiple side color vectors must have same size")
        else:
            # We have one vector of colors
            m, n = 1, len(colors)
            colors = [colors]

        # Map from unique colors to colormap index value
        unique_colors = {}
        matrix = np.zeros((m, n), int)
        for i, inner in enumerate(colors):
            for j, color in enumerate(inner):
                idx = unique_colors.setdefault(color, len(unique_colors))
                matrix[i, j] = idx

        # Reorder for clustering and transpose for axis
        matrix = matrix[:, ind]
        if axis == 0:
            matrix = matrix.T

        cmap = mpl.colors.ListedColormap(list(unique_colors))
        return matrix, cmap

    def plot_dendrograms(self, row_cluster, col_cluster, metric, method,
                         row_linkage, col_linkage, tree_kws):
        """Draw the row and column dendrograms, or blank their Axes.

        The plotters are stored on ``self.dendrogram_row`` and
        ``self.dendrogram_col``; these stay None for an axis that is not
        clustered.
        """
        # Plot the row dendrogram
        if row_cluster:
            self.dendrogram_row = dendrogram(
                self.data2d, metric=metric, method=method, label=False, axis=0,
                ax=self.ax_row_dendrogram, rotate=True, linkage=row_linkage,
                tree_kws=tree_kws
            )
        else:
            self.ax_row_dendrogram.set_xticks([])
            self.ax_row_dendrogram.set_yticks([])
        # Plot the column dendrogram
        if col_cluster:
            self.dendrogram_col = dendrogram(
                self.data2d, metric=metric, method=method, label=False,
                axis=1, ax=self.ax_col_dendrogram, linkage=col_linkage,
                tree_kws=tree_kws
            )
        else:
            self.ax_col_dendrogram.set_xticks([])
            self.ax_col_dendrogram.set_yticks([])
        despine(ax=self.ax_row_dendrogram, bottom=True, left=True)
        despine(ax=self.ax_col_dendrogram, bottom=True, left=True)

    def plot_colors(self, xind, yind, **kws):
        """Plots color labels between the dendrogram and the heatmap

        Parameters
        ----------
        xind, yind : sequence of ints
            Order of the columns / rows, used to reorder the side colors.
        kws : key, value mappings
            Keyword arguments meant for the main heatmap. Entries that only
            make sense for the main heatmap, or that are set explicitly
            here, are dropped before the rest is passed to ``heatmap``.

        """
        # Remove any custom colormap and centering
        # TODO this code has consistently caused problems when we
        # have missed kwargs that need to be excluded that it might
        # be better to rewrite *in*clusively.
        kws = kws.copy()
        kws.pop('cmap', None)
        kws.pop('norm', None)
        kws.pop('center', None)
        kws.pop('annot', None)
        kws.pop('vmin', None)
        kws.pop('vmax', None)
        kws.pop('robust', None)
        kws.pop('xticklabels', None)
        kws.pop('yticklabels', None)
        # ``cbar=False`` is passed explicitly below; a user-supplied value
        # would otherwise be given twice and raise a TypeError
        kws.pop('cbar', None)

        # Plot the row colors
        if self.row_colors is not None:
            matrix, cmap = self.color_list_to_matrix_and_cmap(
                self.row_colors, yind, axis=0)

            # Get row_color labels
            if self.row_color_labels is not None:
                row_color_labels = self.row_color_labels
            else:
                row_color_labels = False

            heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_row_colors,
                    xticklabels=row_color_labels, yticklabels=False, **kws)

            # Adjust rotation of labels
            if row_color_labels is not False:
                plt.setp(self.ax_row_colors.get_xticklabels(), rotation=90)
        else:
            # NOTE(review): __init__ leaves ax_row_colors as None when there
            # are no row colors, and the value is passed positionally here
            # (other calls in this file use ``ax=``); confirm the intent.
            despine(self.ax_row_colors, left=True, bottom=True)

        # Plot the column colors
        if self.col_colors is not None:
            matrix, cmap = self.color_list_to_matrix_and_cmap(
                self.col_colors, xind, axis=1)

            # Get col_color labels
            if self.col_color_labels is not None:
                col_color_labels = self.col_color_labels
            else:
                col_color_labels = False

            heatmap(matrix, cmap=cmap, cbar=False, ax=self.ax_col_colors,
                    xticklabels=False, yticklabels=col_color_labels, **kws)

            # Adjust rotation of labels, place on right side
            if col_color_labels is not False:
                self.ax_col_colors.yaxis.tick_right()
                plt.setp(self.ax_col_colors.get_yticklabels(), rotation=0)
        else:
            # NOTE(review): same situation as for the row colors above
            despine(self.ax_col_colors, left=True, bottom=True)

    def plot_matrix(self, colorbar_kws, xind, yind, **kws):
        """Draw the reordered heatmap and lay out the figure.

        ``self.data2d`` and ``self.mask`` are replaced by their reordered
        versions. Tick labels and annotations given as arrays are reordered
        to match.

        Raises
        ------
        ValueError
            If an ``annot`` array does not have the shape of the data.
        """
        self.data2d = self.data2d.iloc[yind, xind]
        self.mask = self.mask.iloc[yind, xind]

        # Try to reorganize specified tick labels, if provided; values that
        # cannot be indexed this way (e.g. "auto", bools) are kept as given
        xtl = kws.pop("xticklabels", "auto")
        try:
            xtl = np.asarray(xtl)[xind]
        except (TypeError, IndexError):
            pass
        ytl = kws.pop("yticklabels", "auto")
        try:
            ytl = np.asarray(ytl)[yind]
        except (TypeError, IndexError):
            pass

        # Reorganize the annotations to match the heatmap
        annot = kws.pop("annot", None)
        if annot is None or annot is False:
            pass
        else:
            if isinstance(annot, bool):
                annot_data = self.data2d
            else:
                annot_data = np.asarray(annot)
                if annot_data.shape != self.data2d.shape:
                    err = "`data` and `annot` must have same shape."
                    raise ValueError(err)
                annot_data = annot_data[yind][:, xind]
            annot = annot_data

        # Setting ax_cbar=None in clustermap call implies no colorbar
        kws.setdefault("cbar", self.ax_cbar is not None)
        heatmap(self.data2d, ax=self.ax_heatmap, cbar_ax=self.ax_cbar,
                cbar_kws=colorbar_kws, mask=self.mask,
                xticklabels=xtl, yticklabels=ytl, annot=annot, **kws)

        # Move the y ticks to the right, restoring the label rotation that
        # was in effect before the move
        ytl = self.ax_heatmap.get_yticklabels()
        ytl_rot = None if not ytl else ytl[0].get_rotation()
        self.ax_heatmap.yaxis.set_ticks_position('right')
        self.ax_heatmap.yaxis.set_label_position('right')
        if ytl_rot is not None:
            ytl = self.ax_heatmap.get_yticklabels()
            plt.setp(ytl, rotation=ytl_rot)

        tight_params = dict(h_pad=.02, w_pad=.02)
        if self.ax_cbar is None:
            self._figure.tight_layout(**tight_params)
        else:
            # Turn the colorbar axes off for tight layout so that its
            # ticks don't interfere with the rest of the plot layout.
            # Then move it.
            self.ax_cbar.set_axis_off()
            self._figure.tight_layout(**tight_params)
            self.ax_cbar.set_axis_on()
            self.ax_cbar.set_position(self.cbar_pos)

    def plot(self, metric, method, colorbar_kws, row_cluster, col_cluster,
             row_linkage, col_linkage, tree_kws, **kws):
        """Draw the dendrograms, side colors, and heatmap; return ``self``.

        A true ``square`` entry in ``kws`` is dropped with a warning.
        """

        # heatmap square=True sets the aspect ratio on the axes, but that is
        # not compatible with the multi-axes layout of clustergrid
        if kws.get("square", False):
            msg = "``square=True`` ignored in clustermap"
            warnings.warn(msg)
            kws.pop("square")

        colorbar_kws = {} if colorbar_kws is None else colorbar_kws

        self.plot_dendrograms(row_cluster, col_cluster, metric, method,
                              row_linkage=row_linkage, col_linkage=col_linkage,
                              tree_kws=tree_kws)

        # Without clustering the dendrogram attribute is None, and the
        # original order is kept
        try:
            xind = self.dendrogram_col.reordered_ind
        except AttributeError:
            xind = np.arange(self.data2d.shape[1])
        try:
            yind = self.dendrogram_row.reordered_ind
        except AttributeError:
            yind = np.arange(self.data2d.shape[0])

        self.plot_colors(xind, yind, **kws)
        self.plot_matrix(colorbar_kws, xind, yind, **kws)
        return self


def clustermap(
    data, *,
    pivot_kws=None, method='average', metric='euclidean',
    z_score=None, standard_scale=None, figsize=(10, 10),
    cbar_kws=None, row_cluster=True, col_cluster=True,
    row_linkage=None, col_linkage=None,
    row_colors=None, col_colors=None, mask=None,
    dendrogram_ratio=.2, colors_ratio=0.03,
    cbar_pos=(.02, .8, .05, .18), tree_kws=None,
    **kwargs
):
    """
    Plot a matrix dataset as a hierarchically-clustered heatmap.

    This function requires scipy to be available.

    Parameters
    ----------
    data : 2D array-like
        Rectangular data for clustering. Cannot contain NAs.
    pivot_kws : dict, optional
        If `data` is a tidy dataframe, keyword arguments for pivot can be
        provided here to create a rectangular dataframe.
    method : str, optional
        Linkage method used to calculate the clusters. See the
        :func:`scipy.cluster.hierarchy.linkage` documentation for more
        information.
    metric : str, optional
        Distance metric to use for the data. See the
        :func:`scipy.spatial.distance.pdist` documentation for more options.
        To use different metrics (or methods) for rows and columns, you may
        construct each linkage matrix yourself and provide them as
        `{row,col}_linkage`.
    z_score : int or None, optional
        Either 0 (rows) or 1 (columns). Whether or not to calculate z-scores
        for the rows or the columns. Z scores are: z = (x - mean)/std, so
        values in each row (column) will get the mean of the row (column)
        subtracted, then divided by the standard deviation of the row (column).
        This ensures that each row (column) has mean of 0 and variance of 1.
    standard_scale : int or None, optional
        Either 0 (rows) or 1 (columns). Whether or not to standardize that
        dimension, meaning for each row or column, subtract the minimum and
        divide each by its maximum.
    figsize : tuple of (width, height), optional
        Overall size of the figure.
    cbar_kws : dict, optional
        Keyword arguments to pass to `cbar_kws` in :func:`heatmap`, e.g. to
        add a label to the colorbar.
    {row,col}_cluster : bool, optional
        If ``True``, cluster the {rows, columns}.
    {row,col}_linkage : :class:`numpy.ndarray`, optional
        Precomputed linkage matrix for the rows or columns. See
        :func:`scipy.cluster.hierarchy.linkage` for specific formats.
    {row,col}_colors : list-like or pandas DataFrame/Series, optional
        List of colors to label for either the rows or columns. Useful to evaluate
        whether samples within a group are clustered together. Can use nested lists or
        DataFrame for multiple color levels of labeling. If given as a
        :class:`pandas.DataFrame` or :class:`pandas.Series`, labels for the colors are
        extracted from the DataFrame's column names or from the name of the Series.
        DataFrame/Series colors are also matched to the data by their index, ensuring
        colors are drawn in the correct order.
    mask : bool array or DataFrame, optional
        If passed, data will not be shown in cells where `mask` is True.
        Cells with missing values are automatically masked. Only used for
        visualizing, not for calculating.
    {dendrogram,colors}_ratio : float, or pair of floats, optional
        Proportion of the figure size devoted to the two marginal elements. If
        a pair is given, they correspond to (row, col) ratios.
    cbar_pos : tuple of (left, bottom, width, height), optional
        Position of the colorbar axes in the figure. Setting to ``None`` will
        disable the colorbar.
    tree_kws : dict, optional
        Parameters for the :class:`matplotlib.collections.LineCollection`
        that is used to plot the lines of the dendrogram tree.
    kwargs : other keyword arguments
        All other keyword arguments are passed to :func:`heatmap`.

    Returns
    -------
    :class:`ClusterGrid`
        A :class:`ClusterGrid` instance.

    See Also
    --------
    heatmap : Plot rectangular data as a color-encoded matrix.

    Notes
    -----
    The returned object has a ``savefig`` method that should be used if you
    want to save the figure object without clipping the dendrograms.

    To access the reordered row indices, use:
    ``clustergrid.dendrogram_row.reordered_ind``

    Column indices, use:
    ``clustergrid.dendrogram_col.reordered_ind``

    Examples
    --------

    .. include:: ../docstrings/clustermap.rst

    """
    if _no_scipy:
        raise RuntimeError("clustermap requires scipy to be available")

    # Options that shape the grid (data preparation and figure layout)
    grid_options = dict(
        pivot_kws=pivot_kws, figsize=figsize,
        row_colors=row_colors, col_colors=col_colors,
        z_score=z_score, standard_scale=standard_scale,
        mask=mask, dendrogram_ratio=dendrogram_ratio,
        colors_ratio=colors_ratio, cbar_pos=cbar_pos,
    )
    grid = ClusterGrid(data, **grid_options)

    # Options that control the clustering and the drawing itself
    return grid.plot(
        metric=metric, method=method,
        colorbar_kws=cbar_kws,
        row_cluster=row_cluster, col_cluster=col_cluster,
        row_linkage=row_linkage, col_linkage=col_linkage,
        tree_kws=tree_kws, **kwargs,
    )


================================================
FILE: seaborn/miscplot.py
================================================
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

__all__ = ["palplot", "dogplot"]


def palplot(pal, size=1):
    """Draw the colors of a palette side by side in a horizontal strip.

    Parameters
    ----------
    pal : sequence of matplotlib colors
        Colors to show, i.e. as returned by seaborn.color_palette()
    size : scalar
        Scaling factor for the size of the plot; the figure is
        ``len(pal) * size`` wide and ``size`` tall.

    """
    n_swatches = len(pal)
    _, ax = plt.subplots(1, 1, figsize=(n_swatches * size, size))

    swatch_index = np.arange(n_swatches)
    ax.imshow(
        swatch_index.reshape(1, n_swatches),
        cmap=mpl.colors.ListedColormap(list(pal)),
        interpolation="nearest",
        aspect="auto",
    )

    # Ensure nice border between colors: ticks sit half a cell off the centers
    ax.set_xticks(swatch_index - .5)
    ax.set_yticks([-.5, .5])
    ax.set_xticklabels([""] * n_swatches)
    # The proper way to set no ticks
    ax.yaxis.set_major_locator(ticker.NullLocator())


def dogplot(*_, **__):
    """Who's a good boy?

    Download one of the numbered PNG images from the seaborn-data repository
    and show it on a new 5x5 inch figure with the axes turned off. All
    positional and keyword arguments are accepted and ignored.
    """
    from urllib.request import urlopen
    from io import BytesIO

    url = "https://github.com/mwaskom/seaborn-data/raw/master/png/img{}.png"
    # randint excludes the upper bound, so this picks one of img2 ... img6
    pic = np.random.randint(2, 7)
    # Use the response as a context manager so the connection is closed as
    # soon as the image bytes have been read
    with urlopen(url.format(pic)) as response:
        data = BytesIO(response.read())
    img = plt.imread(data)
    f, ax = plt.subplots(figsize=(5, 5), dpi=100)
    # Let the image fill the whole figure
    f.subplots_adjust(0, 0, 1, 1)
    ax.imshow(img)
    ax.set_axis_off()


================================================
FILE: seaborn/objects.py
================================================
"""
A declarative, object-oriented interface for creating statistical graphics.

The seaborn.objects namespace contains a number of classes that can be composed
together to build a customized visualization.

The main object is :class:`Plot`, which is the starting point for all figures.
Pass :class:`Plot` a dataset and specify assignments from its variables to
roles in the plot. Build up the visualization by calling its methods.

There are four other general types of objects in this interface. Three of them are:

- :class:`Mark` subclasses, which create matplotlib artists for visualization
- :class:`Stat` subclasses, which apply statistical transforms before plotting
- :class:`Move` subclasses, which make further adjustments to reduce overplotting

These classes are passed to :meth:`Plot.add` to define a layer in the plot.
Each layer has a :class:`Mark` and optional :class:`Stat` and/or :class:`Move`.
Plots can have multiple layers.

The other general type of object is a :class:`Scale` subclass, which provides an
interface for controlling the mappings between data values and visual properties.
Pass :class:`Scale` objects to :meth:`Plot.scale`.

See the documentation for other :class:`Plot` methods to learn about the many
ways that a plot can be enhanced and customized.

"""
from seaborn._core.plot import Plot  # noqa: F401

from seaborn._marks.base import Mark  # noqa: F401
from seaborn._marks.area import Area, Band  # noqa: F401
from seaborn._marks.bar import Bar, Bars  # noqa: F401
from seaborn._marks.dot import Dot, Dots  # noqa: F401
from seaborn._marks.line import Dash, Line, Lines, Path, Paths, Range  # noqa: F401
from seaborn._marks.text import Text  # noqa: F401

from seaborn._stats.base import Stat  # noqa: F401
from seaborn._stats.aggregation import Agg, Est  # noqa: F401
from seaborn._stats.counting import Count, Hist  # noqa: F401
from seaborn._stats.density import KDE  # noqa: F401
from seaborn._stats.order import Perc  # noqa: F401
from seaborn._stats.regression import PolyFit  # noqa: F401

from seaborn._core.moves import Dodge, Jitter, Norm, Shift, Stack, Move  # noqa: F401

from seaborn._core.scales import (  # noqa: F401
    Boolean, Continuous, Nominal, Temporal, Scale
)


================================================
FILE: seaborn/palettes.py
================================================
import colorsys
from itertools import cycle

import numpy as np
import matplotlib as mpl

from .external import husl

from .utils import desaturate, get_color_cycle
from .colors import xkcd_rgb, crayons
from ._compat import get_colormap


__all__ = ["color_palette", "hls_palette", "husl_palette", "mpl_palette",
           "dark_palette", "light_palette", "diverging_palette",
           "blend_palette", "xkcd_palette", "crayon_palette",
           "cubehelix_palette", "set_color_codes"]


# Hex codes of the named seaborn palettes. Each name has a 10-color version
# and a 6-color version stored under the same name with a "6" suffix.
SEABORN_PALETTES = {
    "deep": [
        "#4C72B0", "#DD8452", "#55A868", "#C44E52", "#8172B3",
        "#937860", "#DA8BC3", "#8C8C8C", "#CCB974", "#64B5CD",
    ],
    "deep6": [
        "#4C72B0", "#55A868", "#C44E52",
        "#8172B3", "#CCB974", "#64B5CD",
    ],
    "muted": [
        "#4878D0", "#EE854A", "#6ACC64", "#D65F5F", "#956CB4",
        "#8C613C", "#DC7EC0", "#797979", "#D5BB67", "#82C6E2",
    ],
    "muted6": [
        "#4878D0", "#6ACC64", "#D65F5F",
        "#956CB4", "#D5BB67", "#82C6E2",
    ],
    "pastel": [
        "#A1C9F4", "#FFB482", "#8DE5A1", "#FF9F9B", "#D0BBFF",
        "#DEBB9B", "#FAB0E4", "#CFCFCF", "#FFFEA3", "#B9F2F0",
    ],
    "pastel6": [
        "#A1C9F4", "#8DE5A1", "#FF9F9B",
        "#D0BBFF", "#FFFEA3", "#B9F2F0",
    ],
    "bright": [
        "#023EFF", "#FF7C00", "#1AC938", "#E8000B", "#8B2BE2",
        "#9F4800", "#F14CC1", "#A3A3A3", "#FFC400", "#00D7FF",
    ],
    "bright6": [
        "#023EFF", "#1AC938", "#E8000B",
        "#8B2BE2", "#FFC400", "#00D7FF",
    ],
    "dark": [
        "#001C7F", "#B1400D", "#12711C", "#8C0800", "#591E71",
        "#592F0D", "#A23582", "#3C3C3C", "#B8850A", "#006374",
    ],
    "dark6": [
        "#001C7F", "#12711C", "#8C0800",
        "#591E71", "#B8850A", "#006374",
    ],
    "colorblind": [
        "#0173B2", "#DE8F05", "#029E73", "#D55E00", "#CC78BC",
        "#CA9161", "#FBAFE4", "#949494", "#ECE133", "#56B4E9",
    ],
    "colorblind6": [
        "#0173B2", "#029E73", "#D55E00",
        "#CC78BC", "#ECE133", "#56B4E9",
    ],
}


# Number of colors associated with each qualitative matplotlib colormap name
MPL_QUAL_PALS = dict(
    tab10=10, tab20=20, tab20b=20, tab20c=20,
    Set1=9, Set2=8, Set3=12,
    Accent=8, Paired=12,
    Pastel1=9, Pastel2=8, Dark2=8,
)


# Sizes of every qualitative palette: the matplotlib entries come first,
# followed by the seaborn palettes (sized by the length of their color lists)
QUAL_PALETTE_SIZES = {
    **MPL_QUAL_PALS,
    **{name: len(colors) for name, colors in SEABORN_PALETTES.items()},
}
QUAL_PALETTES = list(QUAL_PALETTE_SIZES)


class _ColorPalette(list):
    """Set the color palette in a with statement, otherwise be a list."""

    def __enter__(self):
        """Open the context: remember the current palette, then apply this one."""
        from .rcmod import set_palette
        self._orig_palette = color_palette()
        set_palette(self)
        return self

    def __exit__(self, *args):
        """Close the context: restore the palette saved by ``__enter__``."""
        from .rcmod import set_palette
        set_palette(self._orig_palette)

    def as_hex(self):
        """Return a color palette with hex codes instead of RGB values."""
        hex_codes = [mpl.colors.rgb2hex(rgb) for rgb in self]
        return _ColorPalette(hex_codes)

    def _repr_html_(self):
        """Rich display of the color palette in an HTML frontend.

        Returns an SVG image with one square swatch per color, laid out
        left to right in palette order.
        """
        s = 55  # side length of each swatch, in pixels
        swatches = "".join(
            f'<rect x="{i * s}" y="0" width="{s}" height="{s}" style="fill:{c};'
            'stroke-width:2;stroke:rgb(255,255,255)"/>'
            for i, c in enumerate(self.as_hex())
        )
        return f'<svg width="{len(self) * s}" height="{s}">{swatches}</svg>'


def _patch_colormap_display():
    """Simplify the rich display of matplotlib color maps in a notebook.

    Replaces ``_repr_png_`` and ``_repr_html_`` on
    :class:`matplotlib.colors.Colormap` with versions that show only a plain
    horizontal gradient image.
    """
    def _repr_png_(self):
        """Generate a PNG representation of the Colormap."""
        import io
        from PIL import Image
        IMAGE_SIZE = (400, 50)
        # Every row holds the same ramp from 0 to 1 across the image width
        X = np.tile(np.linspace(0, 1, IMAGE_SIZE[0]), (IMAGE_SIZE[1], 1))
        pixels = self(X, bytes=True)
        png_bytes = io.BytesIO()
        Image.fromarray(pixels).save(png_bytes, format='png')
        return png_bytes.getvalue()

    def _repr_html_(self):
        """Generate an HTML representation of the Colormap."""
        import base64
        png_bytes = self._repr_png_()
        png_base64 = base64.b64encode(png_bytes).decode('ascii')
        # Embed the PNG inline as a data URI so no external file is needed
        return (
            f'<img alt="{self.name} color map" title="{self.name}" '
            f'src="data:image/png;base64,{png_base64}">'
        )

    mpl.colors.Colormap._repr_png_ = _repr_png_
    mpl.colors.Colormap._repr_html_ = _repr_html_


def color_palette(palette=None, n_colors=None, desat=None, as_cmap=False):
    """Return a list of colors or continuous colormap defining a palette.

    Possible ``palette`` values include:
        - Name of a seaborn palette (deep, muted, bright, pastel, dark, colorblind)
        - Name of matplotlib colormap
        - 'husl' or 'hls'
        - 'ch:<cubehelix arguments>'
        - 'light:<color>', 'dark:<color>', 'blend:<color>,<color>',
        - A sequence of colors in any format matplotlib accepts

    Calling this function with ``palette=None`` will return the current
    matplotlib color cycle.

    This function can also be used in a ``with`` statement to temporarily
    set the color cycle for a plot or set of plots.

    See the :ref:`tutorial <palette_tutorial>` for more information.

    Parameters
    ----------
    palette : None, string, or sequence, optional
        Name of palette or None to return current palette. If a sequence, input
        colors are used but possibly cycled and desaturated.
    n_colors : int, optional
        Number of colors in the palette. If ``None``, the default will depend
        on how ``palette`` is specified. Named palettes default to 6 colors,
        but grabbing the current palette or passing in a list of colors will
        not change the number of colors unless this is specified. Asking for
        more colors than exist in the palette will cause it to cycle. Ignored
        when ``as_cmap`` is True.
    desat : float, optional
        Proportion to desaturate each color by.
    as_cmap : bool
        If True, return a :class:`matplotlib.colors.ListedColormap`.

    Returns
    -------
    list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    Raises
    ------
    ValueError
        If ``palette`` is a string that does not name a known palette, or if
        one of the resulting colors cannot be converted to RGB.

    See Also
    --------
    set_palette : Set the default color cycle for all plots.
    set_color_codes : Reassign color codes like ``"b"``, ``"g"``, etc. to
                      colors from one of the seaborn palettes.

    Examples
    --------

    .. include:: ../docstrings/color_palette.rst

    """
    if palette is None:
        palette = get_color_cycle()
        if n_colors is None:
            n_colors = len(palette)

    elif not isinstance(palette, str):
        # A sequence of colors is used as given
        if n_colors is None:
            n_colors = len(palette)
    else:

        if n_colors is None:
            # Use all colors in a qualitative palette or 6 of another kind
            n_colors = QUAL_PALETTE_SIZES.get(palette, 6)

        if palette in SEABORN_PALETTES:
            # Named "seaborn variant" of matplotlib default color cycle
            palette = SEABORN_PALETTES[palette]

        elif palette == "hls":
            # Evenly spaced colors in cylindrical RGB space
            palette = hls_palette(n_colors, as_cmap=as_cmap)

        elif palette == "husl":
            # Evenly spaced colors in cylindrical Lab space
            palette = husl_palette(n_colors, as_cmap=as_cmap)

        elif palette.lower() == "jet":
            # Paternalism
            raise ValueError("No.")

        elif palette.startswith("ch:"):
            # Cubehelix palette with params specified in string
            args, kwargs = _parse_cubehelix_args(palette)
            palette = cubehelix_palette(n_colors, *args, **kwargs, as_cmap=as_cmap)

        elif palette.startswith("light:"):
            # light palette to color specified in string
            _, color = palette.split(":")
            reverse = color.endswith("_r")
            if reverse:
                color = color[:-2]
            palette = light_palette(color, n_colors, reverse=reverse, as_cmap=as_cmap)

        elif palette.startswith("dark:"):
            # dark palette to color specified in string
            _, color = palette.split(":")
            reverse = color.endswith("_r")
            if reverse:
                color = color[:-2]
            palette = dark_palette(color, n_colors, reverse=reverse, as_cmap=as_cmap)

        elif palette.startswith("blend:"):
            # blend palette between colors specified in string
            _, colors = palette.split(":")
            colors = colors.split(",")
            palette = blend_palette(colors, n_colors, as_cmap=as_cmap)

        else:
            try:
                # Perhaps a named matplotlib colormap?
                palette = mpl_palette(palette, n_colors, as_cmap=as_cmap)
            except (ValueError, KeyError):  # Error class changed in mpl36
                raise ValueError(f"{palette!r} is not a valid palette name")

    if desat is not None:
        palette = [desaturate(c, desat) for c in palette]

    if not as_cmap:

        # Always return as many colors as we asked for
        pal_cycle = cycle(palette)
        colors = [next(pal_cycle) for _ in range(n_colors)]

        # Always return in r, g, b tuple format. The unconverted colors stay
        # in their own variable so the error message can show them.
        try:
            palette = _ColorPalette(mpl.colors.to_rgb(c) for c in colors)
        except ValueError:
            raise ValueError(f"Could not generate a palette for {colors}")

    return palette


def hls_palette(n_colors=6, h=.01, l=.6, s=.65, as_cmap=False):  # noqa
    """
    Return hues with constant lightness and saturation in the HLS system.

    Hues are sampled at even steps around the color wheel, which makes the
    result suitable for categorical or cyclical data.

    The `h`, `l`, and `s` values should be between 0 and 1.

    .. note::
        While the separation of the resulting colors will be mathematically
        constant, the HLS system does not construct a perceptually-uniform space,
        so their apparent intensity will vary.

    Parameters
    ----------
    n_colors : int
        Number of colors in the palette. Replaced by 256 when `as_cmap` is True.
    h : float
        The value of the first hue.
    l : float
        The lightness value.
    s : float
        The saturation intensity.
    as_cmap : bool
        If True, return a matplotlib colormap object.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    See Also
    --------
    husl_palette : Make a palette using evenly spaced hues in the HUSL system.

    Examples
    --------
    .. include:: ../docstrings/hls_palette.rst

    """
    if as_cmap:
        n_colors = 256

    # Even steps on [0, 1), shifted by the first hue and wrapped back into range
    steps = np.linspace(0, 1, int(n_colors) + 1)[:-1]
    hues = (steps + h) % 1
    hues = hues - hues.astype(int)
    rgb_colors = [colorsys.hls_to_rgb(hue, l, s) for hue in hues]

    if not as_cmap:
        return _ColorPalette(rgb_colors)
    return mpl.colors.ListedColormap(rgb_colors, "hls")


def husl_palette(n_colors=6, h=.01, s=.9, l=.65, as_cmap=False):  # noqa
    """
    Return hues with constant lightness and saturation in the HUSL system.

    Hues are sampled at even steps around the color wheel, which makes the
    result suitable for categorical or cyclical data.

    The `h`, `l`, and `s` values should be between 0 and 1.

    This function is similar to :func:`hls_palette`, but it uses a nonlinear color
    space that is more perceptually uniform.

    Parameters
    ----------
    n_colors : int
        Number of colors in the palette. Replaced by 256 when `as_cmap` is True.
    h : float
        The value of the first hue.
    l : float
        The lightness value.
    s : float
        The saturation intensity.
    as_cmap : bool
        If True, return a matplotlib colormap object.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    See Also
    --------
    hls_palette : Make a palette using evenly spaced hues in the HSL system.

    Examples
    --------
    .. include:: ../docstrings/husl_palette.rst

    """
    if as_cmap:
        n_colors = 256

    # Even steps on [0, 1), shifted by the first hue and wrapped back into range
    steps = np.linspace(0, 1, int(n_colors) + 1)[:-1]
    hues = ((steps + h) % 1) * 359

    # Rescale the unit-interval inputs before the husl conversion
    saturation = s * 99
    lightness = l * 99
    rgb_colors = [
        _color_to_rgb((hue, saturation, lightness), input="husl") for hue in hues
    ]

    if not as_cmap:
        return _ColorPalette(rgb_colors)
    return mpl.colors.ListedColormap(rgb_colors, "hsl")


def mpl_palette(name, n_colors=6, as_cmap=False):
    """
    Return a palette or colormap from the matplotlib registry.

    For continuous palettes, evenly-spaced discrete samples are chosen while
    excluding the minimum and maximum value in the colormap to provide better
    contrast at the extremes.

    For qualitative palettes (e.g. those from colorbrewer), exact values are
    indexed (rather than interpolated), but fewer than `n_colors` can be returned
    if the palette does not define that many.

    Parameters
    ----------
    name : string
        Name of the palette. This should be a named matplotlib colormap. A
        name ending in ``_d`` (optionally ``_r_d``) instead blends two colors
        of the base palette with a dark gray.
    n_colors : int
        Number of discrete colors in the palette.
    as_cmap : bool
        If True, return the colormap object rather than discrete samples.

    Returns
    -------
    list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    Examples
    --------
    .. include:: ../docstrings/mpl_palette.rst

    """
    if not name.endswith("_d"):
        cmap = get_colormap(name)
    else:
        base_name = name[:-2]
        reverse = base_name.endswith("_r")
        if reverse:
            base_name = base_name[:-2]
        # Two colors from the base palette followed by a dark gray anchor
        anchors = color_palette(base_name, 2) + ["#333333"]
        if reverse:
            anchors = anchors[::-1]
        cmap = blend_palette(anchors, n_colors, as_cmap=True)

    if name in MPL_QUAL_PALS:
        # Qualitative maps: take the defined entries, in order
        positions = np.linspace(0, 1, MPL_QUAL_PALS[name])[:n_colors]
    else:
        # Continuous maps: even samples that leave out both extremes
        positions = np.linspace(0, 1, int(n_colors) + 2)[1:-1]
    sampled = [tuple(rgb) for rgb in cmap(positions)[:, :3]]

    return cmap if as_cmap else _ColorPalette(sampled)


def _color_to_rgb(color, input):
    """Add some more flexibility to color choices.

    `color` is first translated according to `input` ("hls", "husl", or
    "xkcd"; any other value leaves it untouched) and the result is passed
    through matplotlib's RGB conversion.
    """
    if input == "hls":
        return mpl.colors.to_rgb(colorsys.hls_to_rgb(*color))
    if input == "husl":
        # Keep the converted channels inside the valid [0, 1] range
        clipped = np.clip(husl.husl_to_rgb(*color), 0, 1)
        return mpl.colors.to_rgb(tuple(clipped))
    if input == "xkcd":
        return mpl.colors.to_rgb(xkcd_rgb[color])
    return mpl.colors.to_rgb(color)


def dark_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"):
    """Make a sequential palette that blends from dark to ``color``.

    This kind of palette is good for data that range between relatively
    uninteresting low values and interesting high values.

    The ``color`` parameter can be specified in a number of ways, including
    all options for defining a color in matplotlib and several additional
    color spaces that are handled by seaborn. You can also use the database
    of named colors from the XKCD color survey.

    If you are using the IPython notebook, you can also choose this palette
    interactively with the :func:`choose_dark_palette` function.

    Parameters
    ----------
    color : base color for high values
        hex, rgb-tuple, or html color name
    n_colors : int, optional
        number of colors in the palette
    reverse : bool, optional
        if True, reverse the direction of the blend
    as_cmap : bool, optional
        If True, return a matplotlib colormap (as built by
        :func:`blend_palette`) instead of a list of colors.
    input : {'rgb', 'hls', 'husl', 'xkcd'}
        Color space to interpret the input color. The first three options
        apply to tuple inputs and the latter applies to string inputs.

    Returns
    -------
    palette
        list of RGB tuples or a matplotlib colormap

    See Also
    --------
    light_palette : Create a sequential palette with bright low values.
    diverging_palette : Create a diverging palette with two colors.

    Examples
    --------
    .. include:: ../docstrings/dark_palette.rst

    """
    target = _color_to_rgb(color, input)
    hue, sat, _ = husl.rgb_to_husl(*target)
    # Dark anchor: same hue, 15% of the saturation, husl lightness of 15
    anchor = _color_to_rgb((hue, .15 * sat, 15), input="husl")
    endpoints = [anchor, target]
    if reverse:
        endpoints.reverse()
    return blend_palette(endpoints, n_colors, as_cmap)


def light_palette(color, n_colors=6, reverse=False, as_cmap=False, input="rgb"):
    """Make a sequential palette that blends from light to ``color``.

    The ``color`` parameter can be specified in a number of ways, including
    all options for defining a color in matplotlib and several additional
    color spaces that are handled by seaborn. You can also use the database
    of named colors from the XKCD color survey.

    If you are using a Jupyter notebook, you can also choose this palette
    interactively with the :func:`choose_light_palette` function.

    Parameters
    ----------
    color : base color for high values
        hex code, html color name, or tuple in `input` space.
    n_colors : int, optional
        number of colors in the palette
    reverse : bool, optional
        if True, reverse the direction of the blend
    as_cmap : bool, optional
        If True, return a matplotlib colormap (as built by
        :func:`blend_palette`) instead of a list of colors.
    input : {'rgb', 'hls', 'husl', 'xkcd'}
        Color space to interpret the input color. The first three options
        apply to tuple inputs and the latter applies to string inputs.

    Returns
    -------
    palette
        list of RGB tuples or a matplotlib colormap

    See Also
    --------
    dark_palette : Create a sequential palette with dark low values.
    diverging_palette : Create a diverging palette with two colors.

    Examples
    --------
    .. include:: ../docstrings/light_palette.rst

    """
    target = _color_to_rgb(color, input)
    hue, sat, _ = husl.rgb_to_husl(*target)
    # Light anchor: same hue, 15% of the saturation, husl lightness of 95
    anchor = _color_to_rgb((hue, .15 * sat, 95), input="husl")
    endpoints = [anchor, target]
    if reverse:
        endpoints.reverse()
    return blend_palette(endpoints, n_colors, as_cmap)


def diverging_palette(h_neg, h_pos, s=75, l=50, sep=1, n=6,  # noqa
                      center="light", as_cmap=False):
    """Make a diverging palette between two HUSL colors.

    If you are using the IPython notebook, you can also choose this palette
    interactively with the :func:`choose_diverging_palette` function.

    Parameters
    ----------
    h_neg, h_pos : float in [0, 359]
        Anchor hues for negative and positive extents of the map.
    s : float in [0, 100], optional
        Anchor saturation for both extents of the map.
    l : float in [0, 100], optional
        Anchor lightness for both extents of the map.
    sep : int, optional
        Size of the intermediate region. ``0`` gives no intermediate region.
    n : int, optional
        Number of colors in the palette (if not returning a cmap)
    center : {"light", "dark"}, optional
        Whether the center of the palette is light or dark
    as_cmap : bool, optional
        If True, return a matplotlib colormap (as built by
        :func:`blend_palette`) instead of a list of colors.

    Returns
    -------
    palette
        list of RGB tuples or a matplotlib colormap

    Raises
    ------
    KeyError
        If `center` is neither ``"light"`` nor ``"dark"``.

    See Also
    --------
    dark_palette : Create a sequential palette with dark values.
    light_palette : Create a sequential palette with light values.

    Examples
    --------
    .. include:: ../docstrings/diverging_palette.rst

    """
    palfunc = dict(dark=dark_palette, light=light_palette)[center]
    # Each half shrinks as the central region grows
    n_half = int(128 - (sep // 2))
    neg = palfunc((h_neg, s, l), n_half, reverse=True, input="husl")
    pos = palfunc((h_pos, s, l), n_half, input="husl")
    midpoint = dict(light=[(.95, .95, .95)], dark=[(.133, .133, .133)])[center]
    mid = midpoint * sep
    # Join as plain lists: np.concatenate cannot combine the two (n, 3) halves
    # with the empty, one-dimensional ``mid`` that sep=0 produces
    colors = list(neg) + mid + list(pos)
    return blend_palette(colors, n, as_cmap=as_cmap)


def blend_palette(colors, n_colors=6, as_cmap=False, input="rgb"):
    """Make a palette that blends between a list of colors.

    Parameters
    ----------
    colors : sequence of colors in various formats interpreted by `input`
        hex code, html color name, or tuple in `input` space.
    n_colors : int, optional
        Number of colors in the palette.
    as_cmap : bool, optional
        If True, return a :class:`matplotlib.colors.LinearSegmentedColormap`.
    input : {'rgb', 'hls', 'husl', 'xkcd'}
        Color space to interpret the input colors. The first three options
        apply to tuple inputs and the latter applies to string inputs.

    Returns
    -------
    palette
        list of RGB tuples or
        :class:`matplotlib.colors.LinearSegmentedColormap`

    Examples
    --------
    .. include:: ../docstrings/blend_palette.rst

    """
    colors = [_color_to_rgb(color, input) for color in colors]
    name = "blend"
    pal = mpl.colors.LinearSegmentedColormap.from_list(name, colors)
    if not as_cmap:
        # Sample the colormap evenly, end points included, and drop alpha
        rgb_array = pal(np.linspace(0, 1, int(n_colors)))[:, :3]
        pal = _ColorPalette(map(tuple, rgb_array))
    return pal


def xkcd_palette(colors):
    """Make a palette with color names from the xkcd color survey.

    See xkcd for the full list of colors: https://xkcd.com/color/rgb/

    This is just a simple wrapper around the `seaborn.xkcd_rgb` dictionary.

    Parameters
    ----------
    colors : list of strings
        List of keys in the `seaborn.xkcd_rgb` dictionary.

    Returns
    -------
    palette
        A list of colors as RGB tuples.

    See Also
    --------
    crayon_palette : Make a palette with Crayola crayon colors.

    """
    # Look every name up first so the palette has exactly one color per name
    looked_up = list(map(xkcd_rgb.__getitem__, colors))
    return color_palette(looked_up, len(looked_up))


def crayon_palette(colors):
    """Make a palette with color names from Crayola crayons.

    Colors are taken from here:
    https://en.wikipedia.org/wiki/List_of_Crayola_crayon_colors

    This is just a simple wrapper around the `seaborn.crayons` dictionary.

    Parameters
    ----------
    colors : list of strings
        List of keys in the `seaborn.crayons` dictionary.

    Returns
    -------
    palette
        A list of colors as RGB tuples.

    See Also
    --------
    xkcd_palette : Make a palette with named colors from the XKCD color survey.

    """
    # Look every name up first so the palette has exactly one color per name
    looked_up = list(map(crayons.__getitem__, colors))
    return color_palette(looked_up, len(looked_up))


def cubehelix_palette(n_colors=6, start=0, rot=.4, gamma=1.0, hue=0.8,
                      light=.85, dark=.15, reverse=False, as_cmap=False):
    """Make a sequential palette from the cubehelix system.

    This produces a colormap with linearly-decreasing (or increasing)
    brightness. That means that information will be preserved if printed to
    black and white or viewed by someone who is colorblind.  "cubehelix" is
    also available as a matplotlib-based palette, but this function gives the
    user more control over the look of the palette and has a different set of
    defaults.

    In addition to using this function, it is also possible to generate a
    cubehelix palette generally in seaborn using a string starting with
    `ch:` and containing other parameters (e.g. `"ch:s=.25,r=-.5"`).

    Parameters
    ----------
    n_colors : int
        Number of colors in the palette.
    start : float, 0 <= start <= 3
        The hue value at the start of the helix.
    rot : float
        Rotations around the hue wheel over the range of the palette.
    gamma : float 0 <= gamma
        Nonlinearity to emphasize dark (gamma < 1) or light (gamma > 1) colors.
    hue : float, 0 <= hue <= 1
        Saturation of the colors.
    dark : float 0 <= dark <= 1
        Intensity of the darkest color in the palette.
    light : float 0 <= light <= 1
        Intensity of the lightest color in the palette.
    reverse : bool
        If True, the palette will go from dark to light.
    as_cmap : bool
        If True, return a :class:`matplotlib.colors.ListedColormap`.

    Returns
    -------
    palette
        list of RGB tuples or :class:`matplotlib.colors.ListedColormap`

    See Also
    --------
    choose_cubehelix_palette : Launch an interactive widget to select cubehelix
                               palette parameters.
    dark_palette : Create a sequential palette with dark low values.
    light_palette : Create a sequential palette with bright low values.

    References
    ----------
    Green, D. A. (2011). "A colour scheme for the display of astronomical
    intensity images". Bulletin of the Astronomical Society of India, Vol. 39,
    p. 289-295.

    Examples
    --------
    .. include:: ../docstrings/cubehelix_palette.rst

    """
    def make_channel(p0, p1):
        # Copied from matplotlib because it lives in private module
        def channel(x):
            # Gamma correction emphasises low or high intensity values
            xg = x ** gamma

            # Amplitude and angle of the deviation from the black-to-white
            # diagonal, in the plane of constant perceived intensity
            a = hue * xg * (1 - xg) / 2
            phi = 2 * np.pi * (start / 3 + rot * x)

            return xg + a * (p0 * np.cos(phi) + p1 * np.sin(phi))
        return channel

    # Per-channel (p0, p1) weights for the cosine and sine terms
    weights = {
        "red": (-0.14861, 1.78277),
        "green": (-0.29227, -0.90649),
        "blue": (1.97294, 0.0),
    }
    cdict = {name: make_channel(p0, p1) for name, (p0, p1) in weights.items()}
    helix = mpl.colors.LinearSegmentedColormap("cubehelix", cdict)

    # The discrete palette is computed unconditionally, so `n_colors` is
    # converted with int() even when a colormap is requested
    positions = np.linspace(light, dark, int(n_colors))
    pal = helix(positions)[:, :3].tolist()
    if reverse:
        pal = pal[::-1]

    if not as_cmap:
        return _ColorPalette(pal)

    # Sample the helix at 256 positions for the colormap lookup table
    positions_256 = np.linspace(light, dark, 256)
    if reverse:
        positions_256 = positions_256[::-1]
    return mpl.colors.ListedColormap(helix(positions_256), "seaborn_cubehelix")


def _parse_cubehelix_args(argstr):
    """Turn stringified cubehelix params into args/kwargs."""

    if argstr.startswith("ch:"):
        argstr = argstr[3:]

    if argstr.endswith("_r"):
        reverse = True
        argstr = argstr[:-2]
    else:
        reverse = False

    if not argstr:
        return [], {"reverse": reverse}

    all_args = argstr.split(",")

    args = [float(a.strip(" ")) for a in all_args if "=" not in a]

    kwargs = [a.split("=") for a in all_args if "=" in a]
    kwargs = {k.strip(" "): float(v.strip(" ")) for k, v in kwargs}

    kwarg_map = dict(
        s="start", r="rot", g="gamma",
        h="hue", l="light", d="dark",  # noqa: E741
    )

    kwargs = {kwarg_map.get(k, k): v for k, v in kwargs.items()}

    if reverse:
        kwargs["reverse"] = True

    return args, kwargs


def set_color_codes(palette="deep"):
    """Change how matplotlib color shorthands are interpreted.

    Calling this will change how shorthand codes like "b" or "g"
    are interpreted by matplotlib in subsequent plots.

    Parameters
    ----------
    palette : {deep, muted, pastel, dark, bright, colorblind}
        Named seaborn palette to use as the source of colors. The special
        value "reset" maps the codes to a fixed built-in list of RGB values
        instead of a seaborn palette.

    Raises
    ------
    TypeError
        If ``palette`` is not a string.
    ValueError
        If ``palette`` is a string that does not name a seaborn palette.

    See Also
    --------
    set : Color codes can be set through the high-level seaborn style
          manager.
    set_palette : Color codes can also be set through the function that
                  sets the matplotlib color cycle.

    """
    if palette == "reset":
        colors = [
            (0., 0., 1.),
            (0., .5, 0.),
            (1., 0., 0.),
            (.75, 0., .75),
            (.75, .75, 0.),
            (0., .75, .75),
            (0., 0., 0.)
        ]
    elif not isinstance(palette, str):
        raise TypeError("set_color_codes requires a named seaborn palette")
    elif palette not in SEABORN_PALETTES:
        raise ValueError(f"Cannot set colors with palette '{palette}'")
    else:
        # Always read the six-color variant, then append a dark gray that
        # lines up with the seventh code ("k")
        six_color_name = palette if palette.endswith("6") else palette + "6"
        colors = SEABORN_PALETTES[six_color_name] + [(.1, .1, .1)]

    converter = mpl.colors.colorConverter
    for code, color in zip("bgrmyck", colors):
        converter.colors[code] = converter.to_rgb(color)


================================================
FILE: seaborn/rcmod.py
================================================
"""Control plot style and scaling using the matplotlib rcParams interface."""
import functools
import matplotlib as mpl
from cycler import cycler
from . import palettes


# Public names exported by this module
__all__ = ["set_theme", "set", "reset_defaults", "reset_orig",
           "axes_style", "set_style", "plotting_context", "set_context",
           "set_palette"]


# rcParams names that make up a "style". `axes_style` discards any key that
# is not in this list, and `_AxesStyle` saves/restores exactly these keys.
_style_keys = [

    "axes.facecolor",
    "axes.edgecolor",
    "axes.grid",
    "axes.axisbelow",
    "axes.labelcolor",

    "figure.facecolor",

    "grid.color",
    "grid.linestyle",

    "text.color",

    "xtick.color",
    "ytick.color",
    "xtick.direction",
    "ytick.direction",
    "lines.solid_capstyle",

    "patch.edgecolor",
    "patch.force_edgecolor",

    "image.cmap",
    "font.family",
    "font.sans-serif",

    "xtick.bottom",
    "xtick.top",
    "ytick.left",
    "ytick.right",

    "axes.spines.left",
    "axes.spines.bottom",
    "axes.spines.right",
    "axes.spines.top",

]

# rcParams names that make up a plotting "context" (font sizes, line widths,
# tick sizes). `plotting_context` only accepts `rc` overrides for these keys,
# and `_PlottingContext` saves/restores exactly these keys.
_context_keys = [

    "font.size",
    "axes.labelsize",
    "axes.titlesize",
    "xtick.labelsize",
    "ytick.labelsize",
    "legend.fontsize",
    "legend.title_fontsize",

    "axes.linewidth",
    "grid.linewidth",
    "lines.linewidth",
    "lines.markersize",
    "patch.linewidth",

    "xtick.major.width",
    "ytick.major.width",
    "xtick.minor.width",
    "ytick.minor.width",

    "xtick.major.size",
    "ytick.major.size",
    "xtick.minor.size",
    "ytick.minor.size",

]


def set_theme(context="notebook", style="darkgrid", palette="deep",
              font="sans-serif", font_scale=1, color_codes=True, rc=None):
    """
    Set aspects of the visual theme for all matplotlib and seaborn plots.

    This function changes the global defaults for all plots using the
    matplotlib rcParams system. The theming is decomposed into several distinct
    sets of parameter values, which are applied in order: context, style,
    palette, and finally any explicit ``rc`` overrides.

    The options are illustrated in the :doc:`aesthetics <../tutorial/aesthetics>`
    and :doc:`color palette <../tutorial/color_palettes>` tutorials.

    Parameters
    ----------
    context : string or dict
        Scaling parameters, see :func:`plotting_context`.
    style : string or dict
        Axes style parameters, see :func:`axes_style`.
    palette : string or sequence
        Color palette, see :func:`color_palette`.
    font : string
        Font family, see matplotlib font manager.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    color_codes : bool
        If ``True`` and ``palette`` is a seaborn palette, remap the shorthand
        color codes (e.g. "b", "g", "r", etc.) to the colors from this palette.
    rc : dict or None
        Dictionary of rc parameter mappings to override the above.

    Examples
    --------

    .. include:: ../docstrings/set_theme.rst

    """
    set_context(context, font_scale=font_scale)

    # The font family is passed as a style-level override
    font_override = {"font.family": font}
    set_style(style, rc=font_override)

    set_palette(palette, color_codes=color_codes)

    # Explicit overrides go in last so that they win over everything above
    if rc is None:
        return
    mpl.rcParams.update(rc)


def set(*args, **kwargs):
    """
    Alias for :func:`set_theme`, which is the preferred interface.

    All positional and keyword arguments are forwarded unchanged to
    :func:`set_theme`.

    This function may be removed in the future.
    """
    set_theme(*args, **kwargs)


def reset_defaults():
    """Restore all RC params to default settings.

    The values are taken from ``matplotlib.rcParamsDefault``.
    """
    mpl.rcParams.update(mpl.rcParamsDefault)


def reset_orig():
    """Restore all RC params to original settings (respects custom rc).

    The values are taken from the package-level ``_orig_rc_params`` object.
    """
    # Imported at call time rather than at module import time
    # NOTE(review): presumably this avoids a circular import with the
    # package __init__ -- confirm
    from . import _orig_rc_params
    mpl.rcParams.update(_orig_rc_params)


def axes_style(style=None, rc=None):
    """
    Get the parameters that control the general style of the plots.

    The style parameters control properties like the color of the background and
    whether a grid is enabled by default. This is accomplished using the
    matplotlib rcParams system.

    The options are illustrated in the
    :doc:`aesthetics tutorial <../tutorial/aesthetics>`.

    This function can also be used as a context manager to temporarily
    alter the global defaults. See :func:`set_theme` or :func:`set_style`
    to modify the global defaults for all plots.

    Parameters
    ----------
    style : None, dict, or one of {darkgrid, whitegrid, dark, white, ticks}
        A dictionary of parameters or the name of a preconfigured style.
        With ``None``, the current rcParams values are returned.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        style dictionaries. This only updates parameters that are
        considered part of the style definition.

    Raises
    ------
    ValueError
        If ``style`` is neither ``None``, a dict, nor a known style name.

    Examples
    --------

    .. include:: ../docstrings/axes_style.rst

    """
    if style is None:
        style_dict = {key: mpl.rcParams[key] for key in _style_keys}

    elif isinstance(style, dict):
        style_dict = style

    else:
        styles = ["white", "dark", "whitegrid", "darkgrid", "ticks"]
        if style not in styles:
            raise ValueError(f"style must be one of {', '.join(styles)}")

        # Gray levels shared by several entries below
        dark_gray = ".15"
        light_gray = ".8"

        # Parameters common to every preset
        style_dict = {

            "figure.facecolor": "white",
            "axes.labelcolor": dark_gray,

            "xtick.direction": "out",
            "ytick.direction": "out",
            "xtick.color": dark_gray,
            "ytick.color": dark_gray,

            "axes.axisbelow": True,
            "grid.linestyle": "-",

            "text.color": dark_gray,
            "font.family": ["sans-serif"],
            "font.sans-serif": ["Arial", "DejaVu Sans", "Liberation Sans",
                                "Bitstream Vera Sans", "sans-serif"],

            "lines.solid_capstyle": "round",
            "patch.edgecolor": "w",
            "patch.force_edgecolor": True,

            "image.cmap": "rocket",

            "xtick.top": False,
            "ytick.right": False,

        }

        # The grid is drawn only for the presets with "grid" in their name
        style_dict["axes.grid"] = "grid" in style

        # Every preset enables all four spines; only the colors differ
        all_spines = {
            "axes.spines.left": True,
            "axes.spines.bottom": True,
            "axes.spines.right": True,
            "axes.spines.top": True,
        }

        # Colors of the background, spines, and grid
        surface = {}
        if style.startswith("dark"):
            surface = {
                "axes.facecolor": "#EAEAF2",
                "axes.edgecolor": "white",
                "grid.color": "white",
                **all_spines,
            }
        elif style == "whitegrid":
            surface = {
                "axes.facecolor": "white",
                "axes.edgecolor": light_gray,
                "grid.color": light_gray,
                **all_spines,
            }
        elif style in ["white", "ticks"]:
            surface = {
                "axes.facecolor": "white",
                "axes.edgecolor": dark_gray,
                "grid.color": light_gray,
                **all_spines,
            }
        style_dict.update(surface)

        # Tick marks are shown only by the "ticks" preset
        show_ticks = bool(style == "ticks")
        style_dict["xtick.bottom"] = show_ticks
        style_dict["ytick.left"] = show_ticks

    # Keep only recognised style keys; this also means that a dict passed in
    # by the caller is copied rather than modified
    style_dict = {
        key: value for key, value in style_dict.items() if key in _style_keys
    }

    # Apply the caller's overrides, again restricted to style keys
    if rc is not None:
        overrides = {key: value for key, value in rc.items() if key in _style_keys}
        style_dict.update(overrides)

    # The wrapper lets the result be used in a with statement
    return _AxesStyle(style_dict)


def set_style(style=None, rc=None):
    """
    Set the parameters that control the general style of the plots.

    The style parameters control properties like the color of the background and
    whether a grid is enabled by default. This is accomplished using the
    matplotlib rcParams system.

    The options are illustrated in the
    :doc:`aesthetics tutorial <../tutorial/aesthetics>`.

    See :func:`axes_style` to get the parameter values.

    Parameters
    ----------
    style : dict, or one of {darkgrid, whitegrid, dark, white, ticks}
        A dictionary of parameters or the name of a preconfigured style.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        style dictionaries. This only updates parameters that are
        considered part of the style definition.

    Examples
    --------

    .. include:: ../docstrings/set_style.rst

    """
    # Resolve the style to a parameter mapping and apply it globally
    mpl.rcParams.update(axes_style(style, rc))


def plotting_context(context=None, font_scale=1, rc=None):
    """
    Get the parameters that control the scaling of plot elements.

    These parameters correspond to label size, line thickness, etc. For more
    information, see the :doc:`aesthetics tutorial <../tutorial/aesthetics>`.

    The base context is "notebook", and the other contexts are "paper", "talk",
    and "poster", which are versions of the notebook parameters scaled by
    different values. Font elements can also be scaled independently of (but
    relative to) the other values.

    This function can also be used as a context manager to temporarily
    alter the global defaults. See :func:`set_theme` or :func:`set_context`
    to modify the global defaults for all plots.

    Parameters
    ----------
    context : None, dict, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
        With ``None``, the current rcParams values are returned. A dictionary
        is used as given (``font_scale`` is not applied to it) and is never
        modified.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        context dictionaries. This only updates parameters that are
        considered part of the context definition.

    Raises
    ------
    ValueError
        If ``context`` is neither ``None``, a dict, nor a known context name.

    Examples
    --------

    .. include:: ../docstrings/plotting_context.rst

    """
    if context is None:
        context_dict = {k: mpl.rcParams[k] for k in _context_keys}

    elif isinstance(context, dict):
        # Work on a copy: the `rc` overrides applied below must not leak
        # back into the dictionary owned by the caller
        context_dict = dict(context)

    else:

        contexts = ["paper", "notebook", "talk", "poster"]
        if context not in contexts:
            raise ValueError(f"context must be in {', '.join(contexts)}")

        # Font sizes are kept separate so they can be rescaled by font_scale
        texts_base_context = {

            "font.size": 12,
            "axes.labelsize": 12,
            "axes.titlesize": 12,
            "xtick.labelsize": 11,
            "ytick.labelsize": 11,
            "legend.fontsize": 11,
            "legend.title_fontsize": 12,

        }

        # Line widths and marker/tick sizes for the "notebook" context
        base_context = {

            "axes.linewidth": 1.25,
            "grid.linewidth": 1,
            "lines.linewidth": 1.5,
            "lines.markersize": 6,
            "patch.linewidth": 1,

            "xtick.major.width": 1.25,
            "ytick.major.width": 1.25,
            "xtick.minor.width": 1,
            "ytick.minor.width": 1,

            "xtick.major.size": 6,
            "ytick.major.size": 6,
            "xtick.minor.size": 4,
            "ytick.minor.size": 4,

        }
        base_context.update(texts_base_context)

        # Scale all the parameters by the same factor depending on the context
        scaling = dict(paper=.8, notebook=1, talk=1.5, poster=2)[context]
        context_dict = {k: v * scaling for k, v in base_context.items()}

        # Now independently scale the fonts
        font_keys = texts_base_context.keys()
        font_dict = {k: context_dict[k] * font_scale for k in font_keys}
        context_dict.update(font_dict)

    # Override these settings with the provided rc dictionary
    if rc is not None:
        rc = {k: v for k, v in rc.items() if k in _context_keys}
        context_dict.update(rc)

    # Wrap in a _PlottingContext object so this can be used in a with statement
    context_object = _PlottingContext(context_dict)

    return context_object


def set_context(context=None, font_scale=1, rc=None):
    """
    Set the parameters that control the scaling of plot elements.

    These parameters correspond to label size, line thickness, etc.
    Calling this function modifies the global matplotlib `rcParams`. For more
    information, see the :doc:`aesthetics tutorial <../tutorial/aesthetics>`.

    The base context is "notebook", and the other contexts are "paper", "talk",
    and "poster", which are versions of the notebook parameters scaled by
    different values. Font elements can also be scaled independently of (but
    relative to) the other values.

    See :func:`plotting_context` to get the parameter values.

    Parameters
    ----------
    context : dict, or one of {paper, notebook, talk, poster}
        A dictionary of parameters or the name of a preconfigured set.
    font_scale : float, optional
        Separate scaling factor to independently scale the size of the
        font elements.
    rc : dict, optional
        Parameter mappings to override the values in the preset seaborn
        context dictionaries. This only updates parameters that are
        considered part of the context definition.

    Examples
    --------

    .. include:: ../docstrings/set_context.rst

    """
    # Resolve the context to a parameter mapping and apply it globally
    mpl.rcParams.update(plotting_context(context, font_scale, rc))


class _RCAesthetics(dict):
    """Dict of rc parameters usable as a context manager or a decorator.

    Subclasses provide ``_keys`` (the rcParams names to save and restore) and
    ``_set`` (the function that applies a parameter mapping).
    """
    def __enter__(self):
        """Remember the current values of the managed keys, then apply self."""
        saved = {}
        for key in self._keys:
            saved[key] = mpl.rcParams[key]
        self._orig = saved
        self._set(self)

    def __exit__(self, exc_type, exc_value, exc_tb):
        """Re-apply the values that were saved on entry."""
        self._set(self._orig)

    def __call__(self, func):
        """Wrap ``func`` so that each call runs with these parameters set."""
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            with self:
                return func(*args, **kwargs)
        return wrapper


class _AxesStyle(_RCAesthetics):
    """Light wrapper on a dict to set style temporarily."""
    # rcParams names saved on entering a ``with`` block and restored on exit
    _keys = _style_keys
    # staticmethod so that ``self._set(params)`` passes only ``params``
    _set = staticmethod(set_style)


class _PlottingContext(_RCAesthetics):
    """Light wrapper on a dict to set context temporarily."""
    # rcParams names saved on entering a ``with`` block and restored on exit
    _keys = _context_keys
    # staticmethod so that ``self._set(params)`` passes only ``params``
    _set = staticmethod(set_context)


def set_palette(palette, n_colors=None, desat=None, color_codes=False):
    """Set the matplotlib color cycle using a seaborn palette.

    Parameters
    ----------
    palette : seaborn color palette | matplotlib colormap | hls | husl
        Palette definition. Should be something :func:`color_palette` can process.
    n_colors : int
        Number of colors in the cycle. The default number of colors will depend
        on the format of ``palette``, see the :func:`color_palette`
        documentation for more information.
    desat : float
        Proportion to desaturate each color by.
    color_codes : bool
        If ``True`` and ``palette`` is a seaborn palette, remap the shorthand
        color codes (e.g. "b", "g", "r", etc.) to the colors from this palette.

    See Also
    --------
    color_palette : build a color palette or set the color cycle temporarily
                    in a ``with`` statement.
    set_context : set parameters to scale plot elements
    set_style : set the default parameters for figure style

    """
    resolved = palettes.color_palette(palette, n_colors, desat)
    mpl.rcParams['axes.prop_cycle'] = cycler('color', resolved)

    if not color_codes:
        return

    try:
        palettes.set_color_codes(palette)
    except (ValueError, TypeError):
        # Remapping only works for named seaborn palettes; for anything else
        # the shorthand codes are deliberately left as they are
        pass


================================================
FILE: seaborn/regression.py
================================================
"""Plotting functions for linear models (broadly construed)."""
import copy
from textwrap import dedent
import warnings
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

# statsmodels is optional: record whether it can be imported so that the
# options that need it can fail with a clear message later on
try:
    import statsmodels
    # Bare reference to the imported name
    # NOTE(review): presumably keeps linters from flagging an unused import
    assert statsmodels
    _has_statsmodels = True
except ImportError:
    _has_statsmodels = False

from . import utils
from . import algorithms as algo
from .axisgrid import FacetGrid, _facet_docs


# Public names exported by this module
__all__ = ["lmplot", "regplot", "residplot"]


class _LinearPlotter:
    """Base class for plotting relational data in tidy format.

    To get anything useful done you'll have to inherit from this, but setup
    code that can be abstracted out should be put here.

    """
    def establish_variables(self, data, **kws):
        """Extract variables from data or use directly."""
        self.data = data

        # Validate the inputs
        any_strings = any([isinstance(v, str) for v in kws.values()])
        if any_strings and data is None:
            raise ValueError("Must pass `data` if using named variables.")

        # Set the variables
        for var, val in kws.items():
            if isinstance(val, str):
                vector = data[val]
            elif isinstance(val, list):
                vector = np.asarray(val)
            else:
                vector = val
            if vector is not None and vector.shape != (1,):
                vector = np.squeeze(vector)
            if np.ndim(vector) > 1:
                err = "regplot inputs must be 1d"
                raise ValueError(err)
            setattr(self, var, vector)

    def dropna(self, *vars):
        """Remove observations with missing data."""
        vals = [getattr(self, var) for var in vars]
        vals = [v for v in vals if v is not None]
        not_na = np.all(np.column_stack([pd.notnull(v) for v in vals]), axis=1)
        for var in vars:
            val = getattr(self, var)
            if val is not None:
                setattr(self, var, val[not_na])

    def plot(self, ax):
        raise NotImplementedError


class _RegressionPlotter(_LinearPlotter):
    """Plotter for numeric independent variables with regression model.

    This does the computations and drawing for the `regplot` function, and
    is thus also used indirectly by `lmplot`.
    """
    def __init__(self, x, y, data=None, x_estimator=None, x_bins=None,
                 x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
                 units=None, seed=None, order=1, logistic=False, lowess=False,
                 robust=False, logx=False, x_partial=None, y_partial=None,
                 truncate=False, dropna=True, x_jitter=None, y_jitter=None,
                 color=None, label=None):
        """Store the plot options and prepare the data vectors.

        Raises
        ------
        ValueError
            If more than one of ``order > 1``, ``logistic``, ``robust``,
            ``lowess`` and ``logx`` is requested.
        """
        # What to draw and how to label it
        self.scatter = scatter
        self.fit_reg = fit_reg
        self.color = color
        self.label = label
        self.truncate = truncate
        self.x_jitter = x_jitter
        self.y_jitter = y_jitter

        # Estimation and bootstrap settings
        self.x_estimator = x_estimator
        self.ci = ci
        if x_ci == "ci":
            # The value "ci" means: reuse the size of the regression interval
            self.x_ci = ci
        else:
            self.x_ci = x_ci
        self.n_boot = n_boot
        self.seed = seed

        # Which regression model to fit
        self.order = order
        self.logistic = logistic
        self.lowess = lowess
        self.robust = robust
        self.logx = logx

        # At most one of the alternative models may be selected
        exclusive_options = (order > 1, logistic, robust, lowess, logx)
        if sum(exclusive_options) > 1:
            raise ValueError("Mutually exclusive regression options.")

        # Pull the vectors out of the arguments or the passed dataframe
        self.establish_variables(
            data, x=x, y=y, units=units,
            x_partial=x_partial, y_partial=y_partial,
        )

        # Drop observations that are null in any of the variables
        if dropna:
            self.dropna("x", "y", "units", "x_partial", "y_partial")

        # Regress nuisance variables out of the data
        if self.x_partial is not None:
            self.x = self.regress_out(self.x, self.x_partial)
        if self.y_partial is not None:
            self.y = self.regress_out(self.y, self.y_partial)

        # Binning the predictor implies a point estimate (the mean by default)
        if x_bins is None:
            self.x_discrete = self.x
        else:
            if x_estimator is None:
                self.x_estimator = np.mean
            self.x_discrete, _ = self.bin_predictor(x_bins)

        # A regression cannot be fit to zero or one observation
        if len(self.x) <= 1:
            self.fit_reg = False

        # Remember the data range for building the prediction grid later
        if self.fit_reg:
            self.x_range = self.x.min(), self.x.max()

    @property
    def scatter_data(self):
        """Data where each observation is a point."""
        x_j = self.x_jitter
        if x_j is None:
            x = self.x
        else:
            x = self.x + np.random.uniform(-x_j, x_j, len(self.x))

        y_j = self.y_jitter
        if y_j is None:
            y = self.y
        else:
            y = self.y + np.random.uniform(-y_j, y_j, len(self.y))

        return x, y

    @property
    def estimate_data(self):
        """Data with a point estimate and CI for each discrete x value."""
        x, y = self.x_discrete, self.y
        vals = sorted(np.unique(x))
        points, cis = [], []

        for val in vals:

            # Get the point estimate of the y variable
            _y = y[x == val]
            est = self.x_estimator(_y)
            points.append(est)

            # Compute the confidence interval for this estimate
            if self.x_ci is None:
                cis.append(None)
            else:
                units = None
                if self.x_ci == "sd":
                    sd = np.std(_y)
                    _ci = est - sd, est + sd
                else:
                    if self.units is not None:
                        units = self.units[x == val]
                    boots = algo.bootstrap(_y,
                                           func=self.x_estimator,
                                           n_boot=self.n_boot,
                                           units=units,
                                           seed=self.seed)
                    _ci = utils.ci(boots, self.x_ci)
                cis.append(_ci)

        return vals, points, cis

    def _check_statsmodels(self):
        """Check whether statsmodels is installed if any boolean options require it."""
        options = "logistic", "robust", "lowess"
        err = "`{}=True` requires statsmodels, an optional dependency, to be installed."
        for option in options:
            if getattr(self, option) and not _has_statsmodels:
                raise RuntimeError(err.format(option))

    def fit_regression(self, ax=None, x_range=None, grid=None):
        """Fit the regression model."""
        self._check_statsmodels()

        # Create the grid for the regression
        if grid is None:
            if self.truncate:
                x_min, x_max = self.x_range
            else:
                if ax is None:
                    x_min, x_max = x_range
                else:
                    x_min, x_max = ax.get_xlim()
            grid = np.linspace(x_min, x_max, 100)
        ci = self.ci

        # Fit the regression
        if self.order > 1:
            yhat, yhat_boots = self.fit_poly(grid, self.order)
        elif self.logistic:
            from statsmodels.genmod.generalized_linear_model import GLM
            from statsmodels.genmod.families import Binomial
            yhat, yhat_boots = self.fit_statsmodels(grid, GLM,
                                                    family=Binomial())
        elif self.lowess:
            ci = None
            grid, yhat = self.fit_lowess()
        elif self.robust:
            from statsmodels.robust.robust_linear_model import RLM
            yhat, yhat_boots = self.fit_statsmodels(grid, RLM)
        elif self.logx:
            yhat, yhat_boots = self.fit_logx(grid)
        else:
            yhat, yhat_boots = self.fit_fast(grid)

        # Compute the confidence interval at each grid point
        if ci is None:
            err_bands = None
        else:
            err_bands = utils.ci(yhat_boots, ci, axis=0)

        return grid, yhat, err_bands

    def fit_fast(self, grid):
        """Low-level regression and prediction using linear algebra."""
        def reg_func(_x, _y):
            return np.linalg.pinv(_x).dot(_y)

        X, y = np.c_[np.ones(len(self.x)), self.x], self.y
        grid = np.c_[np.ones(len(grid)), grid]
        yhat = grid.dot(reg_func(X, y))
        if self.ci is None:
            return yhat, None

        beta_boots = algo.bootstrap(X, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed).T
        yhat_boots = grid.dot(beta_boots).T
        return yhat, yhat_boots

    def fit_poly(self, grid, order):
        """Regression using numpy polyfit for higher-order trends."""
        def reg_func(_x, _y):
            return np.polyval(np.polyfit(_x, _y, order), grid)

        x, y = self.x, self.y
        yhat = reg_func(x, y)
        if self.ci is None:
            return yhat, None

        yhat_boots = algo.bootstrap(x, y,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)
        return yhat, yhat_boots

    def fit_statsmodels(self, grid, model, **kwargs):
        """More general regression function using statsmodels objects.

        ``model`` is called as ``model(y, X, **kwargs)`` and the fitted
        result is used to predict on ``grid``. Returns the fitted values and,
        unless ``self.ci`` is ``None``, the fitted values of each bootstrap
        resample.
        """
        import statsmodels.tools.sm_exceptions as sme

        # Prepend a column of ones for the intercept
        design = np.c_[np.ones(len(self.x)), self.x]
        response = self.y
        grid = np.c_[np.ones(len(grid)), grid]

        # statsmodels>=0.14.0 reports perfect separation as a warning; older
        # versions only have the exception class
        has_warning_class = hasattr(sme, "PerfectSeparationWarning")
        err_classes = (sme.PerfectSeparationError,)
        if has_warning_class:
            err_classes = (*err_classes, sme.PerfectSeparationWarning)

        def reg_func(_x, _y):
            try:
                with warnings.catch_warnings():
                    if has_warning_class:
                        # Turn the warning into an error so it is caught below
                        warnings.simplefilter("error", sme.PerfectSeparationWarning)
                    yhat = model(_y, _x, **kwargs).fit().predict(grid)
            except err_classes:
                # A perfectly separated fit yields all-NaN predictions
                yhat = np.full(len(grid), np.nan)
            return yhat

        yhat = reg_func(design, response)
        if self.ci is None:
            return yhat, None

        return yhat, algo.bootstrap(design, response,
                                    func=reg_func,
                                    n_boot=self.n_boot,
                                    units=self.units,
                                    seed=self.seed)

    def fit_lowess(self):
        """Smooth ``y`` against ``x`` with statsmodels' lowess.

        Unlike the other fitting methods this takes no grid: the smoother's
        output supplies both the grid and the fitted values, which are returned
        as ``(grid, yhat)``.
        """
        from statsmodels.nonparametric.smoothers_lowess import lowess

        smoothed = lowess(self.y, self.x)
        # Split the two output columns into separate vectors
        grid, yhat = smoothed.T
        return grid, yhat

    def fit_logx(self, grid):
        """Fit ``y ~ log(x)`` by least squares and predict it over ``grid``.

        The design matrix is built from the raw ``x`` values and the logarithm
        is applied inside the solver, so the same solver can be handed to the
        bootstrap.

        Returns a ``(yhat, yhat_boots)`` pair, where ``yhat_boots`` is ``None``
        when ``self.ci`` is ``None``.
        """
        def solve(design, response):
            # Log-transform the predictor column, keep the intercept column
            logged = np.c_[design[:, 0], np.log(design[:, 1])]
            return np.linalg.pinv(logged).dot(response)

        design = np.c_[np.ones(len(self.x)), self.x]
        response = self.y
        grid = np.c_[np.ones(len(grid)), np.log(grid)]

        yhat = grid.dot(solve(design, response))
        if self.ci is None:
            return yhat, None

        boot_kws = dict(
            func=solve, n_boot=self.n_boot, units=self.units, seed=self.seed,
        )
        beta_boots = algo.bootstrap(design, response, **boot_kws).T
        return yhat, grid.dot(beta_boots).T

    def bin_predictor(self, bins):
        """Snap every ``x`` value to its nearest bin position.

        A scalar ``bins`` is read as a bin count, with positions placed at
        evenly spaced interior percentiles of ``x``; anything else is flattened
        and used as the positions directly.

        Returns ``(x_binned, bins)``: the snapped values and the positions.
        """
        x = np.asarray(self.x)
        if not np.isscalar(bins):
            bins = np.ravel(bins)
        else:
            # Drop the 0th and 100th percentiles so only interior points remain
            interior = np.linspace(0, 100, bins + 2)[1:-1]
            bins = np.percentile(x, interior)

        # Distance from each observation (rows) to each bin position (columns)
        nearest = np.argmin(np.abs(np.subtract.outer(x, bins)), axis=1)
        return bins[nearest].ravel(), bins

    def regress_out(self, a, b):
        """Remove the linear contribution of ``b`` from ``a``.

        Both inputs are mean-centered, the least-squares projection of ``a``
        onto ``b`` is subtracted, and ``a``'s original mean is added back.
        The result is returned as an array with ``a``'s shape.
        """
        a_mean = a.mean()
        centered = a - a_mean
        confound = np.c_[b - b.mean()]
        weights = np.linalg.pinv(confound).dot(centered)
        residual = centered - confound.dot(weights)
        return np.asarray(residual + a_mean).reshape(centered.shape)

    def plot(self, ax, scatter_kws, line_kws):
        """Draw the scatter and/or regression line onto ``ax``.

        Both keyword dictionaries are updated in place with the label and the
        default color before being handed to the drawing methods.
        """
        # The label goes on the scatter when it is drawn, else on the line
        labelled_kws = scatter_kws if self.scatter else line_kws
        labelled_kws["label"] = self.label

        if self.color is not None:
            color = self.color
        else:
            # Draw (and discard) an empty line to read the next color in the
            # axes property cycle
            probe, = ax.plot([], [])
            color = probe.get_color()
            probe.remove()

        # Ensure that color is hex to avoid matplotlib weirdness
        color = mpl.colors.rgb2hex(mpl.colors.colorConverter.to_rgb(color))

        # A color given explicitly in either dictionary takes precedence
        scatter_kws.setdefault("color", color)
        line_kws.setdefault("color", color)

        if self.scatter:
            self.scatterplot(ax, scatter_kws)
        if self.fit_reg:
            self.lineplot(ax, line_kws)

        # Use the variable names (e.g. from pandas objects) as axis labels
        x_var, y_var = self.x, self.y
        if hasattr(x_var, "name"):
            ax.set_xlabel(x_var.name)
        if hasattr(y_var, "name"):
            ax.set_ylabel(y_var.name)

    def scatterplot(self, ax, kws):
        """Draw the data.

        Without an ``x_estimator`` the raw observations in
        ``self.scatter_data`` are drawn with ``ax.scatter``. With one, the
        point estimates in ``self.estimate_data`` are drawn, plus a vertical
        line per estimate for its interval when intervals are available.

        ``kws`` is updated in place with default aesthetics and must already
        contain a ``"color"`` entry.
        """
        # Treat the line-based markers specially, explicitly setting larger
        # linewidth than is provided by the seaborn style defaults.
        # This would ideally be handled better in matplotlib (i.e., distinguish
        # between edgewidth for solid glyphs and linewidth for line glyphs
        # but this should do for now.
        line_markers = ["1", "2", "3", "4", "+", "x", "|", "_"]
        if self.x_estimator is None:
            if "marker" in kws and kws["marker"] in line_markers:
                lw = mpl.rcParams["lines.linewidth"]
            else:
                lw = mpl.rcParams["lines.markeredgewidth"]
            kws.setdefault("linewidths", lw)

            # Only add a default alpha when the color does not already carry
            # one. Inspect the *last* axis so that a single color passed as a
            # 1-D array (shape (3,) or (4,)) is handled like an (N, 3) or
            # (N, 4) array instead of raising IndexError on ``shape[1]``.
            color_shape = getattr(kws["color"], "shape", ())
            has_alpha = len(color_shape) > 0 and color_shape[-1] >= 4
            if not has_alpha:
                kws.setdefault("alpha", .8)

            x, y = self.scatter_data
            ax.scatter(x, y, **kws)
        else:
            # TODO abstraction
            ci_kws = {"color": kws["color"]}
            if "alpha" in kws:
                ci_kws["alpha"] = kws["alpha"]
            ci_kws["linewidth"] = mpl.rcParams["lines.linewidth"] * 1.75
            kws.setdefault("s", 50)

            xs, ys, cis = self.estimate_data
            # Skip the interval lines entirely when no interval was computed
            if any(ci is not None for ci in cis):
                for x, ci in zip(xs, cis):
                    ax.plot([x, x], ci, **ci_kws)
            ax.scatter(xs, ys, **kws)

    def lineplot(self, ax, kws):
        """Draw the fitted regression line and, if present, its error band.

        ``kws`` is updated in place: an ``"lw"`` entry is removed and used as
        the default for ``"linewidth"``. It must already contain ``"color"``.
        """
        grid, yhat, err_bands = self.fit_regression(ax)
        edges = grid[0], grid[-1]

        # The band reuses the line color; the line defaults to 1.5x the
        # rcParams width unless "lw"/"linewidth" was given
        fill_color = kws["color"]
        default_lw = mpl.rcParams["lines.linewidth"] * 1.5
        kws.setdefault("linewidth", kws.pop("lw", default_lw))

        line, = ax.plot(grid, yhat, **kws)
        if not self.truncate:
            # Prevent mpl from adding margin beyond the ends of the line
            line.sticky_edges.x[:] = edges
        if err_bands is not None:
            ax.fill_between(grid, *err_bands, facecolor=fill_color, alpha=.15)


# Shared docstring fragments that are interpolated (via ``str.format``) into
# the ``lmplot`` and ``regplot`` docstring templates. Most entries end with a
# line continuation so that no trailing newline is inserted into the template.
_regression_docs = dict(

    model_api=dedent("""\
    There are a number of mutually exclusive options for estimating the
    regression model. See the :ref:`tutorial <regression_tutorial>` for more
    information.\
    """),
    regplot_vs_lmplot=dedent("""\
    The :func:`regplot` and :func:`lmplot` functions are closely related, but
    the former is an axes-level function while the latter is a figure-level
    function that combines :func:`regplot` and :class:`FacetGrid`.\
    """),
    x_estimator=dedent("""\
    x_estimator : callable that maps vector -> scalar, optional
        Apply this function to each unique value of ``x`` and plot the
        resulting estimate. This is useful when ``x`` is a discrete variable.
        If ``x_ci`` is given, this estimate will be bootstrapped and a
        confidence interval will be drawn.\
    """),
    x_bins=dedent("""\
    x_bins : int or vector, optional
        Bin the ``x`` variable into discrete bins and then estimate the central
        tendency and a confidence interval. This binning only influences how
        the scatterplot is drawn; the regression is still fit to the original
        data.  This parameter is interpreted either as the number of
        evenly-sized (not necessarily spaced) bins or the positions of the bin
        centers. When this parameter is used, it implies that the default of
        ``x_estimator`` is ``numpy.mean``.\
    """),
    x_ci=dedent("""\
    x_ci : "ci", "sd", int in [0, 100] or None, optional
        Size of the confidence interval used when plotting a central tendency
        for discrete values of ``x``. If ``"ci"``, defer to the value of the
        ``ci`` parameter. If ``"sd"``, skip bootstrapping and show the
        standard deviation of the observations in each bin.\
    """),
    scatter=dedent("""\
    scatter : bool, optional
        If ``True``, draw a scatterplot with the underlying observations (or
        the ``x_estimator`` values).\
    """),
    fit_reg=dedent("""\
    fit_reg : bool, optional
        If ``True``, estimate and plot a regression model relating the ``x``
        and ``y`` variables.\
    """),
    ci=dedent("""\
    ci : int in [0, 100] or None, optional
        Size of the confidence interval for the regression estimate. This will
        be drawn using translucent bands around the regression line. The
        confidence interval is estimated using a bootstrap; for large
        datasets, it may be advisable to avoid that computation by setting
        this parameter to None.\
    """),
    n_boot=dedent("""\
    n_boot : int, optional
        Number of bootstrap resamples used to estimate the ``ci``. The default
        value attempts to balance time and stability; you may want to increase
        this value for "final" versions of plots.\
    """),
    units=dedent("""\
    units : variable name in ``data``, optional
        If the ``x`` and ``y`` observations are nested within sampling units,
        those can be specified here. This will be taken into account when
        computing the confidence intervals by performing a multilevel bootstrap
        that resamples both units and observations (within unit). This does not
        otherwise influence how the regression is estimated or drawn.\
    """),
    seed=dedent("""\
    seed : int, numpy.random.Generator, or numpy.random.RandomState, optional
        Seed or random number generator for reproducible bootstrapping.\
    """),
    order=dedent("""\
    order : int, optional
        If ``order`` is greater than 1, use ``numpy.polyfit`` to estimate a
        polynomial regression.\
    """),
    logistic=dedent("""\
    logistic : bool, optional
        If ``True``, assume that ``y`` is a binary variable and use
        ``statsmodels`` to estimate a logistic regression model. Note that this
        is substantially more computationally intensive than linear regression,
        so you may wish to decrease the number of bootstrap resamples
        (``n_boot``) or set ``ci`` to None.\
    """),
    lowess=dedent("""\
    lowess : bool, optional
        If ``True``, use ``statsmodels`` to estimate a nonparametric lowess
        model (locally weighted linear regression). Note that confidence
        intervals cannot currently be drawn for this kind of model.\
    """),
    robust=dedent("""\
    robust : bool, optional
        If ``True``, use ``statsmodels`` to estimate a robust regression. This
        will de-weight outliers. Note that this is substantially more
        computationally intensive than standard linear regression, so you may
        wish to decrease the number of bootstrap resamples (``n_boot``) or set
        ``ci`` to None.\
    """),
    logx=dedent("""\
    logx : bool, optional
        If ``True``, estimate a linear regression of the form y ~ log(x), but
        plot the scatterplot and regression model in the input space. Note that
        ``x`` must be positive for this to work.\
    """),
    xy_partial=dedent("""\
    {x,y}_partial : strings in ``data`` or matrices
        Confounding variables to regress out of the ``x`` or ``y`` variables
        before plotting.\
    """),
    truncate=dedent("""\
    truncate : bool, optional
        If ``True``, the regression line is bounded by the data limits. If
        ``False``, it extends to the ``x`` axis limits.
    """),
    dropna=dedent("""\
    dropna : bool, optional
        If ``True``, remove observations with missing data from the plot.
    """),
    xy_jitter=dedent("""\
    {x,y}_jitter : floats, optional
        Add uniform random noise of this size to either the ``x`` or ``y``
        variables. The noise is added to a copy of the data after fitting the
        regression, and only influences the look of the scatterplot. This can
        be helpful when plotting variables that take discrete values.\
    """),
    scatter_line_kws=dedent("""\
    {scatter,line}_kws : dictionaries
        Additional keyword arguments to pass to ``plt.scatter`` and
        ``plt.plot``.\
    """),
)
# Make the FacetGrid parameter descriptions available to the same templates
_regression_docs.update(_facet_docs)


def lmplot(
    data, *,
    x=None, y=None, hue=None, col=None, row=None,
    palette=None, col_wrap=None, height=5, aspect=1, markers="o",
    sharex=None, sharey=None, hue_order=None, col_order=None, row_order=None,
    legend=True, legend_out=None, x_estimator=None, x_bins=None,
    x_ci="ci", scatter=True, fit_reg=True, ci=95, n_boot=1000,
    units=None, seed=None, order=1, logistic=False, lowess=False,
    robust=False, logx=False, x_partial=None, y_partial=None,
    truncate=True, x_jitter=None, y_jitter=None, scatter_kws=None,
    line_kws=None, facet_kws=None,
):
    # NOTE: the user-facing docstring is assembled from shared fragments and
    # assigned to ``lmplot.__doc__`` right after this definition.

    # Work on a private copy: the deprecated arguments below are folded into
    # this dictionary, and that must not leak into a dict owned by the caller.
    facet_kws = {} if facet_kws is None else dict(facet_kws)

    def facet_kw_deprecation(key, val):
        # Warn about, and forward, a FacetGrid option passed the old way
        msg = (
            f"{key} is deprecated from the `lmplot` function signature. "
            "Please update your code to pass it using `facet_kws`."
        )
        if val is not None:
            warnings.warn(msg, UserWarning)
            facet_kws[key] = val

    facet_kw_deprecation("sharex", sharex)
    facet_kw_deprecation("sharey", sharey)
    facet_kw_deprecation("legend_out", legend_out)

    if data is None:
        raise TypeError("Missing required keyword argument `data`.")

    # Reduce the dataframe to only needed columns
    need_cols = [x, y, hue, col, row, units, x_partial, y_partial]
    cols = np.unique([a for a in need_cols if a is not None]).tolist()
    data = data[cols]

    # Initialize the grid
    facets = FacetGrid(
        data, row=row, col=col, hue=hue,
        palette=palette,
        row_order=row_order, col_order=col_order, hue_order=hue_order,
        height=height, aspect=aspect, col_wrap=col_wrap,
        **facet_kws,
    )

    # Add the markers here as FacetGrid has figured out how many levels of the
    # hue variable are needed and we don't want to duplicate that process
    if facets.hue_names is None:
        n_markers = 1
    else:
        n_markers = len(facets.hue_names)
    if not isinstance(markers, list):
        markers = [markers] * n_markers
    if len(markers) != n_markers:
        raise ValueError("markers must be a singleton or a list of markers "
                         "for each level of the hue variable")
    facets.hue_kws = {"marker": markers}

    def update_datalim(data, x, y, ax, **kws):
        # Register the x extent of each facet's data with its axes before any
        # drawing happens; the y limits are left untouched here.
        xys = data[[x, y]].to_numpy().astype(float)
        ax.update_datalim(xys, updatey=False)
        ax.autoscale_view(scaley=False)

    facets.map_dataframe(update_datalim, x=x, y=y)

    # Draw the regression plot on each facet
    regplot_kws = dict(
        x_estimator=x_estimator, x_bins=x_bins, x_ci=x_ci,
        scatter=scatter, fit_reg=fit_reg, ci=ci, n_boot=n_boot, units=units,
        seed=seed, order=order, logistic=logistic, lowess=lowess,
        robust=robust, logx=logx, x_partial=x_partial, y_partial=y_partial,
        truncate=truncate, x_jitter=x_jitter, y_jitter=y_jitter,
        scatter_kws=scatter_kws, line_kws=line_kws,
    )
    facets.map_dataframe(regplot, x=x, y=y, **regplot_kws)
    facets.set_axis_labels(x, y)

    # Add a legend
    if legend and (hue is not None) and (hue not in [col, row]):
        facets.add_legend()
    return facets


# Build the ``lmplot`` docstring from a template: the single-brace fields are
# filled in from ``_regression_docs`` by ``str.format``, while doubled braces
# (``{{...}}``) are escapes that render as literal braces.
lmplot.__doc__ = dedent("""\
    Plot data and regression model fits across a FacetGrid.

    This function combines :func:`regplot` and :class:`FacetGrid`. It is
    intended as a convenient interface to fit regression models across
    conditional subsets of a dataset.

    When thinking about how to assign variables to different facets, a general
    rule is that it makes sense to use ``hue`` for the most important
    comparison, followed by ``col`` and ``row``. However, always think about
    your particular dataset and the goals of the visualization you are
    creating.

    {model_api}

    The parameters to this function span most of the options in
    :class:`FacetGrid`, although there may be occasional cases where you will
    want to use that class and :func:`regplot` directly.

    Parameters
    ----------
    {data}
    x, y : strings, optional
        Input variables; these should be column names in ``data``.
    hue, col, row : strings
        Variables that define subsets of the data, which will be drawn on
        separate facets in the grid. See the ``*_order`` parameters to control
        the order of levels of this variable.
    {palette}
    {col_wrap}
    {height}
    {aspect}
    markers : matplotlib marker code or list of marker codes, optional
        Markers for the scatterplot. If a list, each marker in the list will be
        used for each level of the ``hue`` variable.
    {share_xy}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {{hue,col,row}}_order : lists, optional
        Order for the levels of the faceting variables. By default, this will
        be the order that the levels appear in ``data`` or, if the variables
        are pandas categoricals, the category order.
    legend : bool, optional
        If ``True`` and there is a ``hue`` variable, add a legend.
    {legend_out}

        .. deprecated:: 0.12.0
            Pass using the `facet_kws` dictionary.

    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {xy_jitter}
    {scatter_line_kws}
    facet_kws : dict
        Dictionary of keyword arguments for :class:`FacetGrid`.

    Returns
    -------
    :class:`FacetGrid`
        The :class:`FacetGrid` object with the plot on it for further tweaking.

    See Also
    --------
    regplot : Plot data and a conditional model fit.
    FacetGrid : Subplot grid for plotting conditional relationships.
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).

    Notes
    -----

    {regplot_vs_lmplot}

    Examples
    --------

    .. include:: ../docstrings/lmplot.rst

    """).format(**_regression_docs)


def regplot(
    data=None, *, x=None, y=None,
    x_estimator=None, x_bins=None, x_ci="ci",
    scatter=True, fit_reg=True, ci=95, n_boot=1000, units=None,
    seed=None, order=1, logistic=False, lowess=False, robust=False,
    logx=False, x_partial=None, y_partial=None,
    truncate=True, dropna=True, x_jitter=None, y_jitter=None,
    label=None, color=None, marker="o",
    scatter_kws=None, line_kws=None, ax=None
):
    # NOTE: the user-facing docstring is assembled from shared fragments and
    # assigned to ``regplot.__doc__`` right after this definition.

    plotter = _RegressionPlotter(
        x, y, data,
        x_estimator, x_bins, x_ci,
        scatter, fit_reg, ci, n_boot, units, seed,
        order, logistic, lowess, robust, logx,
        x_partial, y_partial, truncate, dropna,
        x_jitter, y_jitter, color, label,
    )

    if ax is None:
        ax = plt.gca()

    # Shallow-copy the keyword dictionaries: the plotter adds entries to them
    # and the caller's objects should be left alone
    scatter_kws = copy.copy(scatter_kws) if scatter_kws is not None else {}
    line_kws = copy.copy(line_kws) if line_kws is not None else {}
    scatter_kws["marker"] = marker

    plotter.plot(ax, scatter_kws, line_kws)
    return ax


# Build the ``regplot`` docstring from a template whose ``{...}`` fields are
# filled in from ``_regression_docs`` by ``str.format``.
regplot.__doc__ = dedent("""\
    Plot data and a linear regression model fit.

    {model_api}

    Parameters
    ----------
    x, y : string, series, or vector array
        Input variables. If strings, these should correspond with column names
        in ``data``. When pandas objects are used, axes will be labeled with
        the series name.
    {data}
    {x_estimator}
    {x_bins}
    {x_ci}
    {scatter}
    {fit_reg}
    {ci}
    {n_boot}
    {units}
    {seed}
    {order}
    {logistic}
    {lowess}
    {robust}
    {logx}
    {xy_partial}
    {truncate}
    {dropna}
    {xy_jitter}
    label : string
        Label to apply to either the scatterplot or regression line (if
        ``scatter`` is ``False``) for use in a legend.
    color : matplotlib color
        Color to apply to all plot elements; will be superseded by colors
        passed in ``scatter_kws`` or ``line_kws``.
    marker : matplotlib marker code
        Marker to use for the scatterplot glyphs.
    {scatter_line_kws}
    ax : matplotlib Axes, optional
        Axes object to draw the plot onto, otherwise uses the current Axes.

    Returns
    -------
    ax : matplotlib Axes
        The Axes object containing the plot.

    See Also
    --------
    lmplot : Combine :func:`regplot` and :class:`FacetGrid` to plot multiple
             linear relationships in a dataset.
    jointplot : Combine :func:`regplot` and :class:`JointGrid` (when used with
                ``kind="reg"``).
    pairplot : Combine :func:`regplot` and :class:`PairGrid` (when used with
               ``kind="reg"``).
    residplot : Plot the residuals of a linear regression model.

    Notes
    -----

    {regplot_vs_lmplot}


    It's also easy to combine :func:`regplot` and :class:`JointGrid` or
    :class:`PairGrid` through the :func:`jointplot` and :func:`pairplot`
    functions, although these do not directly accept all of :func:`regplot`'s
    parameters.

    Examples
    --------

    .. include:: ../docstrings/regplot.rst

    """).format(**_regression_docs)


def residplot(
    data=None, *, x=None, y=None,
    x_partial=None, y_partial=None, lowess=False,
    order=1, robust=False, dropna=True, label=None, color=None,
    scatter_kws=None, line_kws=None, ax=None
):
    """Plot the residuals of a linear regression.

    ``y`` is regressed on ``x`` (optionally as a robust or polynomial
    regression) and the residuals are drawn as a scatterplot around a
    horizontal reference line at zero. A lowess smoother can optionally be
    fit to the residuals, which can help reveal whether any structure
    remains in them.

    Parameters
    ----------
    data : DataFrame, optional
        DataFrame to use if `x` and `y` are column names.
    x : vector or string
        Data or column name in `data` for the predictor variable.
    y : vector or string
        Data or column name in `data` for the response variable.
    {x, y}_partial : vectors or string(s) , optional
        Confounding variables; they are removed from the `x` or `y`
        variables before plotting.
    lowess : boolean, optional
        Fit a lowess smoother to the residual scatterplot.
    order : int, optional
        Order of the polynomial to fit when calculating the residuals.
    robust : boolean, optional
        Fit a robust linear regression when calculating the residuals.
    dropna : boolean, optional
        If True, ignore observations with missing data when fitting and
        plotting.
    label : string, optional
        Label that will be used in any plot legends.
    color : matplotlib color, optional
        Color to use for all elements of the plot.
    {scatter, line}_kws : dictionaries, optional
        Additional keyword arguments passed to scatter() and plot() for drawing
        the components of the plot.
    ax : matplotlib axis, optional
        Plot into this axis, otherwise grab the current axis or make a new
        one if not existing.

    Returns
    -------
    ax: matplotlib axes
        Axes with the regression plot.

    See Also
    --------
    regplot : Plot a simple linear regression model.
    jointplot : Draw a :func:`residplot` with univariate marginal distributions
                (when used with ``kind="resid"``).

    Examples
    --------

    .. include:: ../docstrings/residplot.rst

    """
    plotter = _RegressionPlotter(
        x, y, data, ci=None,
        order=order, robust=robust,
        x_partial=x_partial, y_partial=y_partial,
        dropna=dropna, color=color, label=label,
    )

    if ax is None:
        ax = plt.gca()

    # Predict at the observed x values and swap y for the residuals
    _, fitted, _ = plotter.fit_regression(grid=plotter.x)
    plotter.y = plotter.y - fitted

    # Either smooth the residuals with lowess or draw no fitted line at all
    if not lowess:
        plotter.fit_reg = False
    else:
        plotter.lowess = True

    # Reference line marking a residual of zero
    ax.axhline(0, ls=":", c=".2")

    # Copy the keyword dictionaries so the caller's objects are not modified
    scatter_kws = scatter_kws.copy() if scatter_kws is not None else {}
    line_kws = line_kws.copy() if line_kws is not None else {}
    plotter.plot(ax, scatter_kws, line_kws)
    return ax


================================================
FILE: seaborn/relational.py
================================================
from functools import partial
import warnings

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.cbook import normalize_kwargs

from ._base import (
    VectorPlotter,
)
from .utils import (
    adjust_legend_subtitles,
    _default_color,
    _deprecate_ci,
    _get_transform_functions,
    _scatter_legend_artist,
)
from ._compat import groupby_apply_include_groups
from ._statistics import EstimateAggregator, WeightedAggregator
from .axisgrid import FacetGrid, _facet_docs
from ._docstrings import DocstringComponents, _core_docs


__all__ = ["relplot", "scatterplot", "lineplot"]


# Narrative paragraphs shared by the docstrings of the relational plotting
# functions.
_relational_narrative = DocstringComponents(dict(

    # ---  Introductory prose
    main_api="""
The relationship between `x` and `y` can be shown for different subsets
of the data using the `hue`, `size`, and `style` parameters. These
parameters control what visual semantics are used to identify the different
subsets. It is possible to show up to three dimensions independently by
using all three semantic types, but this style of plot can be hard to
interpret and is often ineffective. Using redundant semantics (i.e. both
`hue` and `style` for the same variable) can be helpful for making
graphics more accessible.

See the :ref:`tutorial <relational_tutorial>` for more information.
    """,

    relational_semantic="""
The default treatment of the `hue` (and to a lesser extent, `size`)
semantic, if present, depends on whether the variable is inferred to
represent "numeric" or "categorical" data. In particular, numeric variables
are represented with a sequential colormap by default, and the legend
entries show regular "ticks" with values that may or may not exist in the
data. This behavior can be controlled through various parameters, as
described and illustrated below.
    """,
))

# Parameter descriptions (numpydoc-style "name : type" entries) keyed by
# parameter name, for reuse across the relational function docstrings.
_relational_docs = dict(

    # --- Shared function parameters
    data_vars="""
x, y : names of variables in `data` or vector data
    Input data variables; must be numeric. Can pass data directly or
    reference columns in `data`.
    """,
    data="""
data : DataFrame, array, or list of arrays
    Input data structure. If `x` and `y` are specified as names, this
    should be a "long-form" DataFrame containing those columns. Otherwise
    it is treated as "wide-form" data and grouping variables are ignored.
    See the examples for the various ways this parameter can be specified
    and the different effects of each.
    """,
    palette="""
palette : string, list, dict, or matplotlib colormap
    An object that determines how colors are chosen when `hue` is used.
    It can be the name of a seaborn palette or matplotlib colormap, a list
    of colors (anything matplotlib understands), a dict mapping levels
    of the `hue` variable to colors, or a matplotlib colormap object.
    """,
    hue_order="""
hue_order : list
    Specified order for the appearance of the `hue` variable levels,
    otherwise they are determined from the data. Not relevant when the
    `hue` variable is numeric.
    """,
    hue_norm="""
hue_norm : tuple or :class:`matplotlib.colors.Normalize` object
    Normalization in data units for colormap applied to the `hue`
    variable when it is numeric. Not relevant if `hue` is categorical.
    """,
    sizes="""
sizes : list, dict, or tuple
    An object that determines how sizes are chosen when `size` is used.
    List or dict arguments should provide a size for each unique data value,
    which forces a categorical interpretation. The argument may also be a
    min, max tuple.
    """,
    size_order="""
size_order : list
    Specified order for appearance of the `size` variable levels,
    otherwise they are determined from the data. Not relevant when the
    `size` variable is numeric.
    """,
    size_norm="""
size_norm : tuple or Normalize object
    Normalization in data units for scaling plot objects when the
    `size` variable is numeric.
    """,
    dashes="""
dashes : boolean, list, or dictionary
    Object determining how to draw the lines for different levels of the
    `style` variable. Setting to `True` will use default dash codes, or
    you can pass a list of dash codes or a dictionary mapping levels of the
    `style` variable to dash codes. Setting to `False` will use solid
    lines for all subsets. Dashes are specified as in matplotlib: a tuple
    of `(segment, gap)` lengths, or an empty string to draw a solid line.
    """,
    markers="""
markers : boolean, list, or dictionary
    Object determining how to draw the markers for different levels of the
    `style` variable. Setting to `True` will use default markers, or
    you can pass a list of markers or a dictionary mapping levels of the
    `style` variable to markers. Setting to `False` will draw
    marker-less lines.  Markers are specified as in matplotlib.
    """,
    style_order="""
style_order : list
    Specified order for appearance of the `style` variable levels
    otherwise they are determined from the data. Not relevant when the
    `style` variable is numeric.
    """,
    units="""
units : vector or key in `data`
    Grouping variable identifying sampling units. When used, a separate
    line will be drawn for each unit with appropriate semantics, but no
    legend entry will be added. Useful for showing distribution of
    experimental replicates when exact identities are not needed.
    """,
    estimator="""
estimator : name of pandas method or callable or None
    Method for aggregating across multiple observations of the `y`
    variable at the same `x` level. If `None`, all observations will
    be drawn.
    """,
    ci="""
ci : int or "sd" or None
    Size of the confidence interval to draw when aggregating.

    .. deprecated:: 0.12.0
        Use the new `errorbar` parameter for more flexibility.

    """,
    n_boot="""
n_boot : int
    Number of bootstraps to use for computing the confidence interval.
    """,
    seed="""
seed : int, numpy.random.Generator, or numpy.random.RandomState
    Seed or random number generator for reproducible bootstrapping.
    """,
    legend="""
legend : "auto", "brief", "full", or False
    How to draw the legend. If "brief", numeric `hue` and `size`
    variables will be represented with a sample of evenly spaced values.
    If "full", every group will get an entry in the legend. If "auto",
    choose between brief or full representation based on number of levels.
    If `False`, no legend data is added and no legend is drawn.
    """,
    ax_in="""
ax : matplotlib Axes
    Axes object to draw the plot onto, otherwise uses the current Axes.
    """,
    ax_out="""
ax : matplotlib Axes
    Returns the Axes object with the plot drawn onto it.
    """,

)


# Combine the parameter documentation from four sources under the namespaces
# ``core``, ``facets``, ``rel`` and ``stat``; the ``stat`` entries are derived
# from the parameters of ``EstimateAggregator.__init__``.
_param_docs = DocstringComponents.from_nested_components(
    core=_core_docs["params"],
    facets=DocstringComponents(_facet_docs),
    rel=DocstringComponents(_relational_docs),
    stat=DocstringComponents.from_function_params(EstimateAggregator.__init__),
)


class _RelationalPlotter(VectorPlotter):
    # Common base class for the relational plotters (e.g. ``_LinePlotter``).

    # Which part of a wide-form input (index, values or columns) feeds each
    # variable -- NOTE(review): interpreted by VectorPlotter; confirm there.
    wide_structure = {
        "x": "@index",
        "y": "@values",
        "hue": "@columns",
        "style": "@columns",
    }

    # TODO where best to define default parameters?
    sort = True


class _LinePlotter(_RelationalPlotter):
    """Plotter that draws one line per semantic subset, with optional errors.

    Observations sharing a value of the `orient` variable can be aggregated
    with an estimator, in which case an error band or error bars can be drawn
    around the estimate.
    """

    # NOTE(review): presumably the artist properties that the legend-building
    # code in the base class may vary across entries -- confirm there.
    _legend_attributes = ["color", "linewidth", "marker", "dashes"]

    def __init__(
        self, *,
        data=None, variables={},
        estimator=None, n_boot=None, seed=None, errorbar=None,
        sort=True, orient="x", err_style=None, err_kws=None, legend=None
    ):
        """Set up the plot data and store the line-specific options.

        `data` and `variables` are forwarded to the parent initializer; all
        other arguments are stored unchanged as attributes and consumed in
        :meth:`plot`. `err_kws=None` is stored as an empty dict.
        """

        # TODO this is messy, we want the mapping to be agnostic about
        # the kind of plot to draw, but for the time being we need to set
        # this information so the SizeMapping can use it
        # (0.5x to 2x the default line width from the rcParams)
        self._default_size_range = (
            np.r_[.5, 2] * mpl.rcParams["lines.linewidth"]
        )

        super().__init__(data=data, variables=variables)

        self.estimator = estimator
        self.errorbar = errorbar
        self.n_boot = n_boot
        self.seed = seed
        self.sort = sort
        self.orient = orient
        self.err_style = err_style
        self.err_kws = {} if err_kws is None else err_kws

        self.legend = legend

    def plot(self, ax, kws):
        """Draw the plot onto an axes, passing matplotlib kwargs.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes that the lines, error representation, and legend are added to.
        kws : dict
            Keyword arguments for the line artists; aliases are normalized
            against :class:`matplotlib.lines.Line2D` before use.

        Raises
        ------
        ValueError
            If `err_style` is not "band", "bars", or None; if `orient` is not
            "x" or "y"; or if aggregation is required while a `units`
            variable is assigned.

        """

        # Draw a test plot, using the passed in kwargs. The goal here is to
        # honor both (a) the current state of the plot cycler and (b) the
        # specified kwargs on all the lines we will draw, overriding when
        # relevant with the data semantics. Note that we won't cycle
        # internally; in other words, if `hue` is not used, all elements will
        # have the same color, but they will have the color that you would have
        # gotten from the corresponding matplotlib function, and calling the
        # function will advance the axes property cycle.

        kws = normalize_kwargs(kws, mpl.lines.Line2D)
        kws.setdefault("markeredgewidth", 0.75)
        kws.setdefault("markeredgecolor", "w")

        # Set default error kwargs
        # (work on a copy so the dict stored on the instance is not modified)
        err_kws = self.err_kws.copy()
        if self.err_style == "band":
            err_kws.setdefault("alpha", .2)
        elif self.err_style == "bars":
            pass
        elif self.err_style is not None:
            err = "`err_style` must be 'band' or 'bars', not {}"
            raise ValueError(err.format(self.err_style))

        # Initialize the aggregation object
        # (the weighted aggregator is used when a "weight" column is present)
        weighted = "weight" in self.plot_data
        agg = (WeightedAggregator if weighted else EstimateAggregator)(
            self.estimator, self.errorbar, n_boot=self.n_boot, seed=self.seed,
        )

        # TODO abstract variable to aggregate over here-ish. Better name?
        # `orient` is the variable we group by; `other` is the one aggregated
        orient = self.orient
        if orient not in {"x", "y"}:
            err = f"`orient` must be either 'x' or 'y', not {orient!r}."
            raise ValueError(err)
        other = {"x": "y", "y": "x"}[orient]

        # TODO How to handle NA? We don't want NA to propagate through to the
        # estimate/CI when some values are present, but we would also like
        # matplotlib to show "gaps" in the line when all values are missing.
        # This is straightforward absent aggregation, but complicated with it.
        # If we want to use nas, we need to conditionalize dropna in iter_data.

        # Loop over the semantic subsets and add to the plot
        grouping_vars = "hue", "size", "style"
        for sub_vars, sub_data in self.iter_data(grouping_vars, from_comp_data=True):

            if self.sort:
                # Sort by units (when assigned), then by the orient variable,
                # then by the other coordinate
                sort_vars = ["units", orient, other]
                sort_cols = [var for var in sort_vars if var in self.variables]
                sub_data = sub_data.sort_values(sort_cols)

            # Only aggregate when an estimator is set and at least one value
            # of the orient variable occurs more than once in this subset
            if (
                self.estimator is not None
                and sub_data[orient].value_counts().max() > 1
            ):
                if "units" in self.variables:
                    # TODO eventually relax this constraint
                    err = "estimator must be None when specifying units"
                    raise ValueError(err)
                grouped = sub_data.groupby(orient, sort=self.sort)
                # Could pass as_index=False instead of reset_index,
                # but that fails on a corner case with older pandas.
                sub_data = (
                    grouped
                    .apply(agg, other, **groupby_apply_include_groups(False))
                    .reset_index()
                )
            else:
                # No aggregation: add empty interval columns so the frame has
                # the `{other}min` / `{other}max` columns in both branches
                sub_data[f"{other}min"] = np.nan
                sub_data[f"{other}max"] = np.nan

            # Apply inverse axis scaling
            # (every column whose name starts with "x" or "y" is transformed,
            # which includes the interval columns added above)
            for var in "xy":
                _, inv = _get_transform_functions(ax, var)
                for col in sub_data.filter(regex=f"^{var}"):
                    sub_data[col] = inv(sub_data[col])

            # --- Draw the main line(s)

            if "units" in self.variables:   # XXX why not add to grouping variables?
                # One line per unit, all sharing this subset's semantics
                lines = []
                for _, unit_data in sub_data.groupby("units"):
                    lines.extend(ax.plot(unit_data["x"], unit_data["y"], **kws))
            else:
                lines = ax.plot(sub_data["x"], sub_data["y"], **kws)

            # Override the artist properties with the semantic mappings
            for line in lines:

                if "hue" in sub_vars:
                    line.set_color(self._hue_map(sub_vars["hue"]))

                if "size" in sub_vars:
                    line.set_linewidth(self._size_map(sub_vars["size"]))

                if "style" in sub_vars:
                    attributes = self._style_map(sub_vars["style"])
                    if "dashes" in attributes:
                        line.set_dashes(attributes["dashes"])
                    if "marker" in attributes:
                        line.set_marker(attributes["marker"])

            # Read the resolved properties back from the last line drawn (the
            # loop variable above) so the error artists can match it
            line_color = line.get_color()
            line_alpha = line.get_alpha()
            line_capstyle = line.get_solid_capstyle()

            # --- Draw the confidence intervals

            if self.estimator is not None and self.errorbar is not None:

                # TODO handling of orientation will need to happen here

                if self.err_style == "band":

                    # fill_between for orient="x", fill_betweenx for orient="y"
                    func = {"x": ax.fill_between, "y": ax.fill_betweenx}[orient]
                    func(
                        sub_data[orient],
                        sub_data[f"{other}min"], sub_data[f"{other}max"],
                        color=line_color, **err_kws
                    )

                elif self.err_style == "bars":

                    # Axes.errorbar takes offsets from the point, so convert
                    # the interval limits into (lower, upper) distances
                    error_param = {
                        f"{other}err": (
                            sub_data[other] - sub_data[f"{other}min"],
                            sub_data[f"{other}max"] - sub_data[other],
                        )
                    }
                    ebars = ax.errorbar(
                        sub_data["x"], sub_data["y"], **error_param,
                        linestyle="", color=line_color, alpha=line_alpha,
                        **err_kws
                    )

                    # Set the capstyle properly on the error bars
                    for obj in ebars.get_children():
                        if isinstance(obj, mpl.collections.LineCollection):
                            obj.set_capstyle(line_capstyle)

        # Finalize the axes details
        self._add_axis_labels(ax)
        if self.legend:
            # Legend handles are empty Line2D artists
            legend_artist = partial(mpl.lines.Line2D, xdata=[], ydata=[])
            attrs = {"hue": "color", "size": "linewidth", "style": None}
            self.add_legend_data(ax, legend_artist, kws, attrs)
            handles, _ = ax.get_legend_handles_labels()
            # Only create a legend when there is something to show
            if handles:
                legend = ax.legend(title=self.legend_title)
                adjust_legend_subtitles(legend)


class _ScatterPlotter(_RelationalPlotter):
    """Plotter that draws all observations as a single scatter collection."""

    # NOTE(review): presumably the artist properties that the legend-building
    # code in the base class may vary across entries -- confirm there.
    _legend_attributes = ["color", "s", "marker"]

    def __init__(self, *, data=None, variables={}, legend=None):
        """Set up the plot data and store the legend option.

        `data` and `variables` are forwarded to the parent initializer.
        """

        # TODO this is messy, we want the mapping to be agnostic about
        # the kind of plot to draw, but for the time being we need to set
        # this information so the SizeMapping can use it
        # (0.5x to 2x the square of the default marker size from the rcParams)
        self._default_size_range = (
            np.r_[.5, 2] * np.square(mpl.rcParams["lines.markersize"])
        )

        super().__init__(data=data, variables=variables)

        self.legend = legend

    def plot(self, ax, kws):
        """Draw the scatter plot onto an axes, passing matplotlib kwargs.

        Parameters
        ----------
        ax : matplotlib Axes
            Axes that the points and legend are added to.
        kws : dict
            Keyword arguments for :meth:`matplotlib.axes.Axes.scatter`;
            aliases are normalized against
            :class:`matplotlib.collections.PathCollection` before use.

        Returns without drawing anything when no complete rows remain after
        dropping missing values.

        """

        # --- Determine the visual attributes of the plot

        data = self.comp_data.dropna()
        if data.empty:
            return

        kws = normalize_kwargs(kws, mpl.collections.PathCollection)

        # Define the vectors of x and y positions
        # (a coordinate that was not assigned falls back to an all-NaN vector)
        empty = np.full(len(data), np.nan)
        x = data.get("x", empty)
        y = data.get("y", empty)

        # Apply inverse scaling to the coordinate variables
        _, inv_x = _get_transform_functions(ax, "x")
        _, inv_y = _get_transform_functions(ax, "y")
        x, y = inv_x(x), inv_y(y)

        if "style" in self.variables:
            # Use a representative marker so scatter sets the edgecolor
            # properly for line art markers. We currently enforce either
            # all or none line art so this works.
            example_level = self._style_map.levels[0]
            example_marker = self._style_map(example_level, "marker")
            kws.setdefault("marker", example_marker)

        # Conditionally set the marker edgecolor based on whether the marker is "filled"
        # See https://github.com/matplotlib/matplotlib/issues/17849 for context
        m = kws.get("marker", mpl.rcParams.get("marker", "o"))
        if not isinstance(m, mpl.markers.MarkerStyle):
            # TODO in more recent matplotlib (which?) can pass a MarkerStyle here
            m = mpl.markers.MarkerStyle(m)
        if m.is_filled():
            kws.setdefault("edgecolor", "w")

        # Draw the scatter plot
        points = ax.scatter(x=x, y=y, **kws)

        # Apply the mapping from semantic variables to artist attributes

        if "hue" in self.variables:
            points.set_facecolors(self._hue_map(data["hue"]))

        if "size" in self.variables:
            points.set_sizes(self._size_map(data["size"]))

        if "style" in self.variables:
            # One marker path per observation
            p = [self._style_map(val, "path") for val in data["style"]]
            points.set_paths(p)

        # Apply dependent default attributes

        if "linewidth" not in kws:
            # Derive the edge width from the 10th percentile of the point
            # sizes, and record it in `kws` so the legend artists receive the
            # same value
            sizes = points.get_sizes()
            linewidth = .08 * np.sqrt(np.percentile(sizes, 10))
            points.set_linewidths(linewidth)
            kws["linewidth"] = linewidth

        # Finalize the axes details
        self._add_axis_labels(ax)
        if self.legend:
            attrs = {"hue": "color", "size": "s", "style": None}
            self.add_legend_data(ax, _scatter_legend_artist, kws, attrs)
            handles, _ = ax.get_legend_handles_labels()
            # Only create a legend when there is something to show
            if handles:
                legend = ax.legend(title=self.legend_title)
                adjust_legend_subtitles(legend)


def lineplot(
    data=None, *,
    x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
    palette=None, hue_order=None, hue_norm=None,
    sizes=None, size_order=None, size_norm=None,
    dashes=True, markers=None, style_order=None,
    estimator="mean", errorbar=("ci", 95), n_boot=1000, seed=None,
    orient="x", sort=True, err_style="band", err_kws=None,
    legend="auto", ci="deprecated", ax=None, **kwargs
):
    # NOTE: the user-facing docstring is attached to `lineplot.__doc__` below.

    # Fold the deprecated `ci` argument into the `errorbar` specification
    errorbar = _deprecate_ci(errorbar, ci)

    # The plotter names the weighting variable "weight" (singular)
    semantic_vars = dict(
        x=x, y=y, hue=hue, size=size, style=style, units=units, weight=weights
    )
    plotter = _LinePlotter(
        data=data,
        variables=semantic_vars,
        estimator=estimator, n_boot=n_boot, seed=seed, errorbar=errorbar,
        sort=sort, orient=orient, err_style=err_style, err_kws=err_kws,
        legend=legend,
    )

    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    plotter.map_size(sizes=sizes, order=size_order, norm=size_norm)
    plotter.map_style(markers=markers, dashes=dashes, order=style_order)

    if ax is None:
        ax = plt.gca()

    # Without a style variable, and unless the caller set a linestyle
    # explicitly, pass `dashes` straight to matplotlib; None or a boolean
    # becomes "" and anything else is forwarded unchanged.  # XXX
    linestyle_keys = {"ls", "linestyle"} & set(kwargs)
    if "style" not in plotter.variables and not linestyle_keys:
        if dashes is None or isinstance(dashes, bool):
            kwargs["dashes"] = ""
        else:
            kwargs["dashes"] = dashes

    # Nothing to draw without x/y data; return the axes untouched
    if not plotter.has_xy_data:
        return ax

    plotter._attach(ax)

    # Other functions have color as an explicit param,
    # and we should probably do that here too.
    # Both aliases are removed from kwargs; "color" wins over "c".
    short_color = kwargs.pop("c", None)
    color = kwargs.pop("color", short_color)
    kwargs["color"] = _default_color(ax.plot, hue, color, kwargs)

    plotter.plot(ax, kwargs)
    return ax


lineplot.__doc__ = """\
Draw a line plot with possibility of several semantic groupings.

{narrative.main_api}

{narrative.relational_semantic}

By default, the plot aggregates over multiple `y` values at each value of
`x` and shows an estimate of the central tendency and a confidence
interval for that estimate.

Parameters
----------
{params.core.data}
{params.core.xy}
hue : vector or key in `data`
    Grouping variable that will produce lines with different colors.
    Can be either categorical or numeric, although color mapping will
    behave differently in the latter case.
size : vector or key in `data`
    Grouping variable that will produce lines with different widths.
    Can be either categorical or numeric, although size mapping will
    behave differently in the latter case.
style : vector or key in `data`
    Grouping variable that will produce lines with different dashes
    and/or markers. Can have a numeric dtype but will always be treated
    as categorical.
{params.rel.units}
weights : vector or key in `data`
    Data values or column used to compute weighted estimation.
    Note that use of weights currently limits the choice of statistics
    to a 'mean' estimator and 'ci' errorbar.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
{params.rel.sizes}
{params.rel.size_order}
{params.rel.size_norm}
{params.rel.dashes}
{params.rel.markers}
{params.rel.style_order}
{params.rel.estimator}
{params.stat.errorbar}
{params.rel.n_boot}
{params.rel.seed}
orient : "x" or "y"
    Dimension along which the data are sorted / aggregated. Equivalently,
    the "independent variable" of the resulting function.
sort : boolean
    If True, the data will be sorted by the `orient` variable and then by the
    other coordinate variable; otherwise lines will connect points in the
    order they appear in the dataset.
err_style : "band" or "bars"
    Whether to draw the confidence intervals with translucent error bands
    or discrete error bars.
err_kws : dict of keyword arguments
    Additional parameters to control the aesthetics of the error bars. The
    kwargs are passed either to :meth:`matplotlib.axes.Axes.fill_between`
    or :meth:`matplotlib.axes.Axes.errorbar`, depending on `err_style`.
{params.rel.legend}
{params.rel.ci}
{params.core.ax}
kwargs : key, value mappings
    Other keyword arguments are passed down to
    :meth:`matplotlib.axes.Axes.plot`.

Returns
-------
{returns.ax}

See Also
--------
{seealso.scatterplot}
{seealso.pointplot}

Examples
--------

.. include:: ../docstrings/lineplot.rst

""".format(
    narrative=_relational_narrative,
    params=_param_docs,
    returns=_core_docs["returns"],
    seealso=_core_docs["seealso"],
)


def scatterplot(
    data=None, *,
    x=None, y=None, hue=None, size=None, style=None,
    palette=None, hue_order=None, hue_norm=None,
    sizes=None, size_order=None, size_norm=None,
    markers=True, style_order=None, legend="auto", ax=None,
    **kwargs
):
    # NOTE: the user-facing docstring is attached to `scatterplot.__doc__` below.

    semantic_vars = dict(x=x, y=y, hue=hue, size=size, style=style)
    plotter = _ScatterPlotter(
        data=data, variables=semantic_vars, legend=legend,
    )

    # Configure the semantic mappings before touching the axes
    plotter.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    plotter.map_size(sizes=sizes, order=size_order, norm=size_norm)
    plotter.map_style(markers=markers, order=style_order)

    if ax is None:
        ax = plt.gca()

    # Nothing to draw without x/y data; return the axes untouched
    if not plotter.has_xy_data:
        return ax

    plotter._attach(ax)

    # Resolve the color to draw with: None when `hue` is assigned, otherwise
    # the given color or a default taken from the axes
    requested_color = kwargs.pop("color", None)
    kwargs["color"] = _default_color(ax.scatter, hue, requested_color, kwargs)

    plotter.plot(ax, kwargs)

    return ax


scatterplot.__doc__ = """\
Draw a scatter plot with possibility of several semantic groupings.

{narrative.main_api}

{narrative.relational_semantic}

Parameters
----------
{params.core.data}
{params.core.xy}
hue : vector or key in `data`
    Grouping variable that will produce points with different colors.
    Can be either categorical or numeric, although color mapping will
    behave differently in the latter case.
size : vector or key in `data`
    Grouping variable that will produce points with different sizes.
    Can be either categorical or numeric, although size mapping will
    behave differently in the latter case.
style : vector or key in `data`
    Grouping variable that will produce points with different markers.
    Can have a numeric dtype but will always be treated as categorical.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
{params.rel.sizes}
{params.rel.size_order}
{params.rel.size_norm}
{params.rel.markers}
{params.rel.style_order}
{params.rel.legend}
{params.core.ax}
kwargs : key, value mappings
    Other keyword arguments are passed down to
    :meth:`matplotlib.axes.Axes.scatter`.

Returns
-------
{returns.ax}

See Also
--------
{seealso.lineplot}
{seealso.stripplot}
{seealso.swarmplot}

Examples
--------

.. include:: ../docstrings/scatterplot.rst

""".format(
    narrative=_relational_narrative,
    params=_param_docs,
    returns=_core_docs["returns"],
    seealso=_core_docs["seealso"],
)


def relplot(
    data=None, *,
    x=None, y=None, hue=None, size=None, style=None, units=None, weights=None,
    row=None, col=None, col_wrap=None, row_order=None, col_order=None,
    palette=None, hue_order=None, hue_norm=None,
    sizes=None, size_order=None, size_norm=None,
    markers=None, dashes=None, style_order=None,
    legend="auto", kind="scatter", height=5, aspect=1, facet_kws=None,
    **kwargs
):
    # NOTE: the user-facing docstring is attached to `relplot.__doc__` below.
    #
    # Outline:
    #   1. Pick the plotter class / axes-level function for `kind`.
    #   2. Map the semantics once on the full dataset so every facet shares
    #      the same palette, sizes, and styles.
    #   3. Re-assign the variables including row/col, rename the columns with
    #      a leading underscore, and build a FacetGrid from that frame.
    #   4. Draw each facet with the axes-level function, then add one
    #      figure-level legend.
    #   5. Restore user-facing column names on the grid's `data` attribute.

    # The default for `markers` / `dashes` is resolved per kind: each is
    # enabled only for the kind that uses it unless the caller set it.
    if kind == "scatter":

        Plotter = _ScatterPlotter
        func = scatterplot
        markers = True if markers is None else markers

    elif kind == "line":

        Plotter = _LinePlotter
        func = lineplot
        dashes = True if dashes is None else dashes

    else:
        err = f"Plot kind {kind} not recognized"
        raise ValueError(err)

    # Check for attempt to plot onto specific axes and warn
    if "ax" in kwargs:
        msg = (
            "relplot is a figure-level function and does not accept "
            "the `ax` parameter. You may wish to try {}".format(kind + "plot")
        )
        warnings.warn(msg, UserWarning)
        kwargs.pop("ax")

    # Use the full dataset to map the semantics
    variables = dict(x=x, y=y, hue=hue, size=size, style=style)
    if kind == "line":
        # The plotter names the weighting variable "weight" (singular)
        variables["units"] = units
        variables["weight"] = weights
    else:
        if units is not None:
            msg = "The `units` parameter has no effect with kind='scatter'."
            warnings.warn(msg, stacklevel=2)
        if weights is not None:
            msg = "The `weights` parameter has no effect with kind='scatter'."
            warnings.warn(msg, stacklevel=2)
    p = Plotter(
        data=data,
        variables=variables,
        legend=legend,
    )
    p.map_hue(palette=palette, order=hue_order, norm=hue_norm)
    p.map_size(sizes=sizes, order=size_order, norm=size_norm)
    p.map_style(markers=markers, dashes=dashes, order=style_order)

    # Extract the semantic mappings
    # (these resolved values replace the user-provided arguments so that each
    # facet is drawn with identical mappings)
    if "hue" in p.variables:
        palette = p._hue_map.lookup_table
        hue_order = p._hue_map.levels
        hue_norm = p._hue_map.norm
    else:
        palette = hue_order = hue_norm = None

    # Note: unlike hue and style, the size arguments are left as passed when
    # no size variable is assigned
    if "size" in p.variables:
        sizes = p._size_map.lookup_table
        size_order = p._size_map.levels
        size_norm = p._size_map.norm

    if "style" in p.variables:
        style_order = p._style_map.levels
        if markers:
            markers = {k: p._style_map(k, "marker") for k in style_order}
        else:
            markers = None
        if dashes:
            dashes = {k: p._style_map(k, "dashes") for k in style_order}
        else:
            dashes = None
    else:
        markers = dashes = style_order = None

    # Now extract the data that would be used to draw a single plot
    # (captured before `assign_variables` is called again below; the legend
    # code further down puts `plot_data` back on the plotter)
    variables = p.variables
    plot_data = p.plot_data

    # Define the common plotting parameters
    # (the legend is disabled per facet; a single legend is added at the end)
    plot_kws = dict(
        palette=palette, hue_order=hue_order, hue_norm=hue_norm,
        sizes=sizes, size_order=size_order, size_norm=size_norm,
        markers=markers, dashes=dashes, style_order=style_order,
        legend=False,
    )
    plot_kws.update(kwargs)
    if kind == "scatter":
        # scatterplot has no `dashes` parameter
        plot_kws.pop("dashes")

    # Add the grid semantics onto the plotter
    grid_variables = dict(
        x=x, y=y, row=row, col=col, hue=hue, size=size, style=style,
    )
    if kind == "line":
        # Here the key is "weights", matching the lineplot keyword; the
        # re-keying of `plot_variables` below relies on this name
        grid_variables.update(units=units, weights=weights)
    p.assign_variables(data, grid_variables)

    # Define the named variables for plotting on each facet
    # Rename the variables with a leading underscore to avoid
    # collisions with faceting variable names
    plot_variables = {v: f"_{v}" for v in variables}
    if "weight" in plot_variables:
        # lineplot accepts `weights`, so pass the "_weight" column under
        # that keyword
        plot_variables["weights"] = plot_variables.pop("weight")
    plot_kws.update(plot_variables)

    # Pass the row/col variables to FacetGrid with their original
    # names so that the axes titles render correctly
    for var in ["row", "col"]:
        # Handle faceting variables that lack name information
        if var in p.variables and p.variables[var] is None:
            p.variables[var] = f"_{var}_"
    grid_kws = {v: p.variables.get(v) for v in ["row", "col"]}

    # Rename the columns of the plot_data structure appropriately
    new_cols = plot_variables.copy()
    new_cols.update(grid_kws)
    full_data = p.plot_data.rename(columns=new_cols)

    # Set up the FacetGrid object
    # (copy `facet_kws` so the caller's dict is not shared; columns that are
    # entirely missing are dropped before handing the frame to the grid)
    facet_kws = {} if facet_kws is None else facet_kws.copy()
    g = FacetGrid(
        data=full_data.dropna(axis=1, how="all"),
        **grid_kws,
        col_wrap=col_wrap, row_order=row_order, col_order=col_order,
        height=height, aspect=aspect, dropna=False,
        **facet_kws
    )

    # Draw the plot
    g.map_dataframe(func, **plot_kws)

    # Label the axes, using the original variables
    # Pass "" when the variable name is None to overwrite internal variables
    g.set_axis_labels(variables.get("x") or "", variables.get("y") or "")

    if legend:
        # Replace the original plot data so the legend uses numeric data with
        # the correct type, since we force a categorical mapping above.
        p.plot_data = plot_data

        # Handle the additional non-semantic keyword arguments out here.
        # We're selective because some kwargs may be seaborn function specific
        # and not relevant to the matplotlib artists going into the legend.
        # Ideally, we will have a better solution where we don't need to re-make
        # the legend out here and will have parity with the axes-level functions.
        keys = ["c", "color", "alpha", "m", "marker"]
        if kind == "scatter":
            legend_artist = _scatter_legend_artist
            keys += ["s", "facecolor", "fc", "edgecolor", "ec", "linewidth", "lw"]
        else:
            legend_artist = partial(mpl.lines.Line2D, xdata=[], ydata=[])
            keys += [
                "markersize", "ms",
                "markeredgewidth", "mew",
                "markeredgecolor", "mec",
                "linestyle", "ls",
                "linewidth", "lw",
            ]

        # Keep only the user kwargs that are meaningful for the legend artist
        common_kws = {k: v for k, v in kwargs.items() if k in keys}
        attrs = {"hue": "color", "style": None}
        if kind == "scatter":
            attrs["size"] = "s"
        elif kind == "line":
            attrs["size"] = "linewidth"
        # The legend data is built against the first axes of the grid
        p.add_legend_data(g.axes.flat[0], legend_artist, common_kws, attrs)
        if p.legend_data:
            g.add_legend(legend_data=p.legend_data,
                         label_order=p.legend_order,
                         title=p.legend_title,
                         adjust_subtitles=True)

    # Rename the columns of the FacetGrid's `data` attribute
    # to match the original column names
    # (variables without a name become "_<var>_")
    orig_cols = {
        f"_{k}": f"_{k}_" if v is None else v for k, v in variables.items()
    }
    grid_data = g.data.rename(columns=orig_cols)
    if data is not None and (x is not None or y is not None):
        if not isinstance(data, pd.DataFrame):
            data = pd.DataFrame(data)
        # Join the input data with the grid columns it does not already
        # have, aligning on the index
        g.data = pd.merge(
            data,
            grid_data[grid_data.columns.difference(data.columns)],
            left_index=True,
            right_index=True,
        )
    else:
        g.data = grid_data

    return g


relplot.__doc__ = """\
Figure-level interface for drawing relational plots onto a FacetGrid.

This function provides access to several different axes-level functions
that show the relationship between two variables with semantic mappings
of subsets. The `kind` parameter selects the underlying axes-level
function to use:

- :func:`scatterplot` (with `kind="scatter"`; the default)
- :func:`lineplot` (with `kind="line"`)

Extra keyword arguments are passed to the underlying function, so you
should refer to the documentation for each to see kind-specific options.

{narrative.main_api}

{narrative.relational_semantic}

After plotting, the :class:`FacetGrid` with the plot is returned and can
be used directly to tweak supporting plot details or add other layers.

Parameters
----------
{params.core.data}
{params.core.xy}
hue : vector or key in `data`
    Grouping variable that will produce elements with different colors.
    Can be either categorical or numeric, although color mapping will
    behave differently in the latter case.
size : vector or key in `data`
    Grouping variable that will produce elements with different sizes.
    Can be either categorical or numeric, although size mapping will
    behave differently in the latter case.
style : vector or key in `data`
    Grouping variable that will produce elements with different styles.
    Can have a numeric dtype but will always be treated as categorical.
{params.rel.units}
weights : vector or key in `data`
    Data values or column used to compute weighted estimation.
    Note that use of weights currently limits the choice of statistics
    to a 'mean' estimator and 'ci' errorbar.
{params.facets.rowcol}
{params.facets.col_wrap}
row_order, col_order : lists of strings
    Order to organize the rows and/or columns of the grid in, otherwise the
    orders are inferred from the data objects.
{params.core.palette}
{params.core.hue_order}
{params.core.hue_norm}
{params.rel.sizes}
{params.rel.size_order}
{params.rel.size_norm}
{params.rel.style_order}
{params.rel.dashes}
{params.rel.markers}
{params.rel.legend}
kind : string
    Kind of plot to draw, corresponding to a seaborn relational plot.
    Options are `"scatter"` or `"line"`.
{params.facets.height}
{params.facets.aspect}
facet_kws : dict
    Dictionary of other keyword arguments to pass to :class:`FacetGrid`.
kwargs : key, value pairings
    Other keyword arguments are passed through to the underlying plotting
    function.

Returns
-------
{returns.facetgrid}

Examples
--------

.. include:: ../docstrings/relplot.rst

""".format(
    narrative=_relational_narrative,
    params=_param_docs,
    returns=_core_docs["returns"],
)


================================================
FILE: seaborn/utils.py
================================================
"""Utility functions, mostly for internal use."""
import os
import inspect
import warnings
import colorsys
from contextlib import contextmanager
from urllib.request import urlopen, urlretrieve
from types import ModuleType

import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.colors import to_rgb
import matplotlib.pyplot as plt
from matplotlib.cbook import normalize_kwargs

from seaborn._core.typing import deprecated
from seaborn.external.version import Version
from seaborn.external.appdirs import user_cache_dir

__all__ = ["desaturate", "saturate", "set_hls_values", "move_legend",
           "despine", "get_dataset_names", "get_data_home", "load_dataset"]

DATASET_SOURCE = "https://raw.githubusercontent.com/mwaskom/seaborn-data/master"
DATASET_NAMES_URL = f"{DATASET_SOURCE}/dataset_names.txt"


def ci_to_errsize(cis, heights):
    """Express interval limits as distances below and above each height.

    Parameters
    ----------
    cis : 2 x n sequence
        lower limits (first row) and upper limits (second row)
    heights : n sequence
        plot heights that the intervals are centered around

    Returns
    -------
    errsize : 2 x n array
        distance from each height down to its lower limit (first row) and up
        to its upper limit (second row), in the layout that plt.bar expects
        for its error argument

    """
    lower, upper = np.atleast_2d(cis).reshape(2, -1)
    heights = np.atleast_1d(heights)

    # Build one [below, above] pair per interval, then flip to 2 x n
    pairs = [
        [heights[idx] - lo, hi - heights[idx]]
        for idx, (lo, hi) in enumerate(zip(lower, upper))
    ]
    return np.asarray(pairs).T


def _draw_figure(fig):
    """Draw a matplotlib figure right away, tolerating back-compat gaps."""
    # See https://github.com/matplotlib/matplotlib/issues/19197 for context
    fig.canvas.draw()
    if not fig.stale:
        return

    # The canvas draw left the figure marked stale, so draw through the
    # figure itself. An AttributeError here is deliberately ignored
    # (presumably raised by canvases without `get_renderer`).
    try:
        redraw = fig.draw
        redraw(fig.canvas.get_renderer())
    except AttributeError:
        pass


def _default_color(method, hue, color, kws, saturation=1):
    """If needed, get a default color by using the matplotlib property cycle.

    A temporary "scout" artist is drawn with `method` and removed again; its
    resolved color is what matplotlib would have used for a real artist.

    Parameters
    ----------
    method : bound Axes plotting method
        The branch taken is selected by `method.__name__`; "plot", "scatter",
        "bar", and "fill_between" are recognized.
    hue : object or None
        When not None, colors come from the hue mapping and None is returned.
    color : matplotlib color or None
        Explicit color; when given it is returned (desaturated if requested)
        without drawing a scout artist.
    kws : dict
        Keyword arguments intended for `method`. A copy is used, so the
        caller's dict is not modified; "label" is dropped from the copy.
    saturation : float
        When less than 1, the resolved color is passed through `desaturate`.

    Returns
    -------
    color : matplotlib color or None
        None when `hue` is assigned or when no single default color could be
        determined (e.g. the method is not recognized, or the scatter kwargs
        specify their own colors).

    """

    if hue is not None:
        # This warning is probably user-friendly, but it's currently triggered
        # in a FacetGrid context and I don't want to mess with that logic right now
        #  if color is not None:
        #      msg = "`color` is ignored when `hue` is assigned."
        #      warnings.warn(msg)
        return None

    # Work on a copy and drop the label so the scout artist does not carry it
    kws = kws.copy()
    kws.pop("label", None)

    if color is not None:
        if saturation < 1:
            color = desaturate(color, saturation)
        return color

    elif method.__name__ == "plot":

        color = normalize_kwargs(kws, mpl.lines.Line2D).get("color")
        scout, = method([], [], scalex=False, scaley=False, color=color)
        color = scout.get_color()
        scout.remove()

    elif method.__name__ == "scatter":

        # Matplotlib will raise if the size of x/y don't match s/c,
        # and the latter might be in the kws dict
        scout_size = max(
            np.atleast_1d(kws.get(key, [])).shape[0]
            for key in ["s", "c", "fc", "facecolor", "facecolors"]
        )
        scout_x = scout_y = np.full(scout_size, np.nan)

        scout = method(scout_x, scout_y, **kws)
        facecolors = scout.get_facecolors()

        if not len(facecolors):
            # Handle bug in matplotlib <= 3.2 (I think)
            # This will limit the ability to use non color= kwargs to specify
            # a color in versions of matplotlib with the bug, but trying to
            # work out what the user wanted by re-implementing the broken logic
            # of inspecting the kwargs is probably too brittle.
            single_color = False
        else:
            single_color = np.unique(facecolors, axis=0).shape[0] == 1

        # Allow the user to specify an array of colors through various kwargs
        if "c" not in kws and single_color:
            color = to_rgb(facecolors[0])

        scout.remove()

    elif method.__name__ == "bar":

        # bar() needs masked, not empty data, to generate a patch
        scout, = method([np.nan], [np.nan], **kws)
        color = to_rgb(scout.get_facecolor())
        scout.remove()
        # Axes.bar adds both a patch and a container
        method.__self__.containers.pop(-1)

    elif method.__name__ == "fill_between":

        kws = normalize_kwargs(kws, mpl.collections.PolyCollection)
        scout = method([], [], **kws)
        facecolor = scout.get_facecolor()
        color = to_rgb(facecolor[0])
        scout.remove()

    # `color` is still None when the method was not recognized or when the
    # scatter kwargs carry their own colors; there is nothing to desaturate
    # then, and passing None on to `desaturate` would raise.
    if color is not None and saturation < 1:
        color = desaturate(color, saturation)

    return color


def desaturate(color, prop):
    """Scale down the saturation of a color by a multiplicative factor.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    prop : float
        factor, between 0 and 1, applied to the saturation channel

    Returns
    -------
    new_color : rgb tuple
        desaturated color code in RGB tuple representation

    Raises
    ------
    ValueError
        If ``prop`` is outside the closed interval [0, 1].

    """
    if not 0 <= prop <= 1:
        raise ValueError("prop must be between 0 and 1")

    rgb = to_rgb(color)

    # Returning early sidesteps the lossy rgb -> hls -> rgb round trip
    if prop == 1:
        return rgb

    # Only the saturation channel is scaled; hue and lightness pass through
    hue, lightness, saturation = colorsys.rgb_to_hls(*rgb)
    return colorsys.hls_to_rgb(hue, lightness, saturation * prop)


def saturate(color):
    """Push a color to full saturation while keeping its hue.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name

    Returns
    -------
    new_color : rgb tuple
        saturated color code in RGB tuple representation

    """
    # Delegate to the generic HLS setter, overriding only the s channel
    new_color = set_hls_values(color, s=1)
    return new_color


def set_hls_values(color, h=None, l=None, s=None):  # noqa
    """Override any of the h, l, or s channels of a color.

    Parameters
    ----------
    color : matplotlib color
        hex, rgb-tuple, or html color name
    h, l, s : floats between 0 and 1, or None
        replacement value for each hls channel; None keeps the current value

    Returns
    -------
    new_color : rgb tuple
        new color code in RGB tuple representation

    """
    current = colorsys.rgb_to_hls(*to_rgb(color))
    # Pair each existing channel with its requested override (if any)
    updated = [
        old if new is None else new
        for old, new in zip(current, (h, l, s))
    ]
    return colorsys.hls_to_rgb(*updated)


def axlabel(xlabel, ylabel, **kwargs):
    """Label the x and y axes of the current matplotlib Axes.

    DEPRECATED: will be removed in a future version.

    Extra keyword arguments are forwarded to both ``set_xlabel`` and
    ``set_ylabel``.

    """
    warnings.warn(
        "This function is deprecated and will be removed in a future version",
        FutureWarning,
    )
    current_axes = plt.gca()
    current_axes.set_xlabel(xlabel, **kwargs)
    current_axes.set_ylabel(ylabel, **kwargs)


def remove_na(vector):
    """Drop null entries from a data vector.

    Parameters
    ----------
    vector : vector object
        Must implement boolean masking with [] subscript syntax.

    Returns
    -------
    clean_vector : same type as ``vector``
        Vector of data with null values removed. May be a copy or a view.

    """
    # Boolean mask that is True wherever the entry is not null
    keep = pd.notnull(vector)
    return vector[keep]


def get_color_cycle():
    """Return the colors of the active matplotlib property cycle.

    Parameters
    ----------
    None

    Returns
    -------
    colors : list
        List of matplotlib colors in the current cycle, or dark gray if
        the current color cycle is empty.
    """
    prop_cycle = mpl.rcParams['axes.prop_cycle']
    if 'color' in prop_cycle.keys:
        return prop_cycle.by_key()['color']
    # The cycle does not vary color, so fall back to a single dark gray
    return [".15"]


def _despine_shift_ticks(axis, position):
    """Move the ticks of one axis to `position`, carrying over visibility."""
    # Record visibility before moving, since the ticks end up on the
    # secondary (tick2) line afterwards
    major_on = any(tick.tick1line.get_visible() for tick in axis.majorTicks)
    minor_on = any(tick.tick1line.get_visible() for tick in axis.minorTicks)
    axis.set_ticks_position(position)
    for tick in axis.majorTicks:
        tick.tick2line.set_visible(major_on)
    for tick in axis.minorTicks:
        tick.tick2line.set_visible(minor_on)


def _despine_trim_axis(ax, which, spine_names):
    """Bound the named spines by the outermost in-view ticks of one axis."""
    ticks = np.asarray(getattr(ax, f"get_{which}ticks")())
    if not ticks.size:
        return

    get_limits = getattr(ax, f"get_{which}lim")
    # min()/max() make this independent of whether the axis is inverted
    first = np.compress(ticks >= min(get_limits()), ticks)[0]
    last = np.compress(ticks <= max(get_limits()), ticks)[-1]

    for name in spine_names:
        ax.spines[name].set_bounds(first, last)

    # Drop ticks that now fall outside the trimmed spine
    kept = ticks.compress(ticks <= last)
    kept = kept.compress(kept >= first)
    getattr(ax, f"set_{which}ticks")(kept)


def despine(fig=None, ax=None, top=True, right=True, left=False,
            bottom=False, offset=None, trim=False):
    """Remove the top and right spines from plot(s).

    Parameters
    ----------
    fig : matplotlib figure, optional
        Figure to despine all axes of, defaults to the current figure.
    ax : matplotlib axes, optional
        Specific axes object to despine. Ignored if fig is provided.
    top, right, left, bottom : boolean, optional
        If True, remove that spine.
    offset : int or dict, optional
        Absolute distance, in points, spines should be moved away
        from the axes (negative values move spines inward). A single value
        applies to all spines; a dict can be used to set offset values per
        side.
    trim : bool, optional
        If True, limit spines to the smallest and largest major tick
        on each non-despined axis.

    Returns
    -------
    None

    """
    # Work out which axes to operate on; `fig` wins over `ax`
    if fig is not None:
        targets = fig.axes
    elif ax is not None:
        targets = [ax]
    else:
        targets = plt.gcf().axes

    removed = {"top": top, "right": right, "left": left, "bottom": bottom}

    for ax_i in targets:
        for side, remove in removed.items():
            keep = not remove
            spine = ax_i.spines[side]
            spine.set_visible(keep)
            if offset is None or not keep:
                continue
            # `offset` is either a per-side mapping or one scalar for all
            try:
                shift = offset.get(side, 0)
            except AttributeError:
                shift = offset
            spine.set_position(('outward', shift))

        # When only the far spine survives, move the ticks over to it
        if left and not right:
            _despine_shift_ticks(ax_i.yaxis, "right")
        if bottom and not top:
            _despine_shift_ticks(ax_i.xaxis, "top")

        if trim:
            # clip off the parts of the spines that extend past major ticks
            _despine_trim_axis(ax_i, "x", ("bottom", "top"))
            _despine_trim_axis(ax_i, "y", ("left", "right"))


def move_legend(obj, loc, **kwargs):
    """
    Recreate a plot's legend at a new location.

    The name is a slight misnomer. Matplotlib legends do not expose public
    control over their position parameters. So this function creates a new legend,
    copying over the data from the original object, which is then removed.

    Parameters
    ----------
    obj : the object with the plot
        This argument can be either a seaborn or matplotlib object:

        - :class:`seaborn.FacetGrid` or :class:`seaborn.PairGrid`
        - :class:`matplotlib.axes.Axes` or :class:`matplotlib.figure.Figure`

    loc : str or int
        Location argument, as in :meth:`matplotlib.axes.Axes.legend`.

    kwargs
        Other keyword arguments are passed to :meth:`matplotlib.axes.Axes.legend`.
        A ``labels`` entry replaces the existing labels (it must have the same
        length), ``title`` replaces the title text, and any ``title_*`` entries
        are applied as properties of the title text object.

    Raises
    ------
    TypeError
        If ``obj`` is not a seaborn Grid or a matplotlib Axes or Figure.
    ValueError
        If ``obj`` has no legend, or if ``labels`` is given with a length
        that differs from the existing legend.

    Examples
    --------

    .. include:: ../docstrings/move_legend.rst

    """
    # This is a somewhat hackish solution that will hopefully be obviated by
    # upstream improvements to matplotlib legends that make them easier to
    # modify after creation.

    from seaborn.axisgrid import Grid  # Avoid circular import

    # Locate the legend object and a method to recreate the legend
    if isinstance(obj, Grid):
        old_legend = obj.legend
        legend_func = obj.figure.legend
    elif isinstance(obj, mpl.axes.Axes):
        old_legend = obj.legend_
        legend_func = obj.legend
    elif isinstance(obj, mpl.figure.Figure):
        # A figure can hold several legends; operate on the last one in the list
        if obj.legends:
            old_legend = obj.legends[-1]
        else:
            old_legend = None
        legend_func = obj.legend
    else:
        err = "`obj` must be a seaborn Grid or matplotlib Axes or Figure instance."
        raise TypeError(err)

    if old_legend is None:
        err = f"{obj} has no legend attached."
        raise ValueError(err)

    # Extract the components of the legend we need to reuse
    # Import here to avoid a circular import
    from seaborn._compat import get_legend_handles
    handles = get_legend_handles(old_legend)
    labels = [t.get_text() for t in old_legend.get_texts()]

    # Handle the case where the user is trying to override the labels
    if (new_labels := kwargs.pop("labels", None)) is not None:
        if len(new_labels) != len(labels):
            err = "Length of new labels does not match existing legend."
            raise ValueError(err)
        labels = new_labels

    # Extract legend properties that can be passed to the recreation method
    # (Vexingly, these don't all round-trip)
    # Only properties whose names are also Legend constructor parameters are kept
    legend_kws = inspect.signature(mpl.legend.Legend).parameters
    props = {k: v for k, v in old_legend.properties().items() if k in legend_kws}

    # Delegate default bbox_to_anchor rules to matplotlib
    props.pop("bbox_to_anchor")

    # Try to propagate the existing title and font properties; respect new ones too
    title = props.pop("title")
    if "title" in kwargs:
        title.set_text(kwargs.pop("title"))
    # Collect the title_* entries into a separate dict first so that `kwargs`
    # can be mutated safely inside the loop below
    title_kwargs = {k: v for k, v in kwargs.items() if k.startswith("title_")}
    for key, val in title_kwargs.items():
        # key[6:] strips the "title_" prefix to get the text property name
        title.set(**{key[6:]: val})
        kwargs.pop(key)

    # Try to respect the frame visibility
    kwargs.setdefault("frameon", old_legend.legendPatch.get_visible())

    # Remove the old legend and create the new one
    # User-supplied kwargs take precedence over the copied properties
    props.update(kwargs)
    old_legend.remove()
    new_legend = legend_func(handles, labels, loc=loc, **props)
    new_legend.set_title(title.get_text(), title.get_fontproperties())

    # Let the Grid object continue to track the correct legend object
    if isinstance(obj, Grid):
        obj._legend = new_legend


def _kde_support(data, bw, gridsize, cut, clip):
    """Establish support for a kernel density estimate."""
    support_min = max(data.min() - bw * cut, clip[0])
    support_max = min(data.max() + bw * cut, clip[1])
    support = np.linspace(support_min, support_max, gridsize)

    return support


def ci(a, which=95, axis=None):
    """Return the central ``which`` percent range of an array, ignoring NaNs."""
    half_width = which / 2
    bounds = 50 - half_width, 50 + half_width
    return np.nanpercentile(a, bounds, axis)


def get_dataset_names():
    """List the example datasets that are available for download.

    Useful when reporting issues. Requires an internet connection.

    """
    with urlopen(DATASET_NAMES_URL) as response:
        payload = response.read()

    # One name per line; blank lines (including a trailing one) are dropped
    lines = payload.decode().split("\n")
    return [stripped for line in lines if (stripped := line.strip())]


def get_data_home(data_home=None):
    """Return a path to the cache directory for example datasets.

    This directory is used by :func:`load_dataset`.

    If the ``data_home`` argument is not provided, it will use a directory
    specified by the `SEABORN_DATA` environment variable (if it exists)
    or otherwise default to an OS-appropriate user cache location.

    Parameters
    ----------
    data_home : str, optional
        Explicit cache location. A leading ``~`` is expanded to the user's
        home directory.

    Returns
    -------
    data_home : str
        The resolved path. The directory is created if nothing exists there.

    """
    if data_home is None:
        data_home = os.environ.get("SEABORN_DATA", user_cache_dir("seaborn"))
    data_home = os.path.expanduser(data_home)
    if not os.path.exists(data_home):
        # exist_ok guards against another process creating the directory
        # between the existence check above and this call
        os.makedirs(data_home, exist_ok=True)
    return data_home


def load_dataset(name, cache=True, data_home=None, **kws):
    """Load an example dataset from the online repository (requires internet).

    This function provides quick access to a small number of example datasets
    that are useful for documenting seaborn or generating reproducible examples
    for bug reports. It is not necessary for normal usage.

    Note that some of the datasets have a small amount of preprocessing applied
    to define a proper ordering for categorical variables.

    Use :func:`get_dataset_names` to see a list of available datasets.

    Parameters
    ----------
    name : str
        Name of the dataset (``{name}.csv`` on
        https://github.com/mwaskom/seaborn-data).
    cache : boolean, optional
        If True, try to load from the local cache first, and save to the cache
        if a download is required.
    data_home : string, optional
        The directory in which to cache data; see :func:`get_data_home`.
    kws : keys and values, optional
        Additional keyword arguments are passed through to
        :func:`pandas.read_csv`.

    Returns
    -------
    df : :class:`pandas.DataFrame`
        Tabular data, possibly with some preprocessing applied.

    Raises
    ------
    TypeError
        If ``name`` is a :class:`pandas.DataFrame` rather than a string.
    ValueError
        If ``cache`` is True, the dataset is not in the local cache, and
        ``name`` is not one of the available example datasets.

    """
    # A common beginner mistake is to assume that one's personal data needs
    # to be passed through this function to be usable with seaborn.
    # Let's provide a more helpful error than you would otherwise get.
    if isinstance(name, pd.DataFrame):
        err = (
            "This function accepts only strings (the name of an example dataset). "
            "You passed a pandas DataFrame. If you have your own dataset, "
            "it is not necessary to use this function before plotting."
        )
        raise TypeError(err)

    url = f"{DATASET_SOURCE}/{name}.csv"

    if cache:
        cache_path = os.path.join(get_data_home(data_home), os.path.basename(url))
        if not os.path.exists(cache_path):
            # Validate the name before downloading so that a typo gives a
            # clear error instead of caching a failed download
            if name not in get_dataset_names():
                raise ValueError(f"'{name}' is not one of the example datasets.")
            urlretrieve(url, cache_path)
        full_path = cache_path
    else:
        full_path = url

    df = pd.read_csv(full_path, **kws)

    # Drop a trailing all-null row. The length check keeps this from raising
    # an IndexError when the frame has no rows (e.g. ``nrows=0`` in ``kws``).
    if len(df) and df.iloc[-1].isnull().all():
        df = df.iloc[:-1]

    # Set some columns as a categorical type with ordered levels

    if name == "tips":
        df["day"] = pd.Categorical(df["day"], ["Thur", "Fri", "Sat", "Sun"])
        df["sex"] = pd.Categorical(df["sex"], ["Male", "Female"])
        df["time"] = pd.Categorical(df["time"], ["Lunch", "Dinner"])
        df["smoker"] = pd.Categorical(df["smoker"], ["Yes", "No"])

    elif name == "flights":
        # Abbreviate month names; category order follows order of appearance
        months = df["month"].str[:3]
        df["month"] = pd.Categorical(months, months.unique())

    elif name == "exercise":
        df["time"] = pd.Categorical(df["time"], ["1 min", "15 min", "30 min"])
        df["kind"] = pd.Categorical(df["kind"], ["rest", "walking", "running"])
        df["diet"] = pd.Categorical(df["diet"], ["no fat", "low fat"])

    elif name == "titanic":
        df["class"] = pd.Categorical(df["class"], ["First", "Second", "Third"])
        df["deck"] = pd.Categorical(df["deck"], list("ABCDEFG"))

    elif name == "penguins":
        df["sex"] = df["sex"].str.title()

    elif name == "diamonds":
        df["color"] = pd.Categorical(
            df["color"], ["D", "E", "F", "G", "H", "I", "J"],
        )
        df["clarity"] = pd.Categorical(
            df["clarity"], ["IF", "VVS1", "VVS2", "VS1", "VS2", "SI1", "SI2", "I1"],
        )
        df["cut"] = pd.Categorical(
            df["cut"], ["Ideal", "Premium", "Very Good", "Good", "Fair"],
        )

    elif name == "taxis":
        df["pickup"] = pd.to_datetime(df["pickup"])
        df["dropoff"] = pd.to_datetime(df["dropoff"])

    elif name == "seaice":
        df["Date"] = pd.to_datetime(df["Date"])

    elif name == "dowjones":
        df["Date"] = pd.to_datetime(df["Date"])

    return df


def axis_ticklabels_overlap(labels):
    """Report whether any tick labels in a list overlap one another.

    Parameters
    ----------
    labels : list of matplotlib ticklabels

    Returns
    -------
    overlap : boolean
        True if any of the labels overlap.

    """
    if not labels:
        return False

    try:
        extents = [label.get_window_extent() for label in labels]
        # Every box overlaps itself, so a count above 1 means a real overlap
        most_overlaps = max(box.count_overlaps(extents) for box in extents)
    except RuntimeError:
        # Issue on macos backend raises an error in the above code
        return False
    return most_overlaps > 1


def axes_ticklabels_overlap(ax):
    """Check the x and y tick labels of an Axes for overlaps.

    Parameters
    ----------
    ax : matplotlib Axes

    Returns
    -------
    x_overlap, y_overlap : booleans
        True when the labels on that axis overlap.

    """
    x_overlap = axis_ticklabels_overlap(ax.get_xticklabels())
    y_overlap = axis_ticklabels_overlap(ax.get_yticklabels())
    return x_overlap, y_overlap


def locator_to_legend_entries(locator, limits, dtype):
    """Return levels and formatted levels for brief numeric legends.

    Tick values are requested from ``locator`` over ``limits``, cast to
    ``dtype``, restricted to the limits, and formatted as strings.
    """
    ticks = locator.tick_values(*limits).astype(dtype)

    # The locator can return ticks outside the limits, clip them here
    lower, upper = limits[0], limits[1]
    raw_levels = [tick for tick in ticks if tick >= lower and tick <= upper]

    class dummy_axis:
        # Minimal stand-in giving the formatter a view interval to work with
        def get_view_interval(self):
            return limits

    if not isinstance(locator, mpl.ticker.LogLocator):
        formatter = mpl.ticker.ScalarFormatter()
        # Avoid having an offset/scientific notation which we don't currently
        # have any way of representing in the legend
        formatter.set_useOffset(False)
        formatter.set_scientific(False)
    else:
        formatter = mpl.ticker.LogFormatter()
    formatter.axis = dummy_axis()

    return raw_levels, formatter.format_ticks(raw_levels)


def relative_luminance(color):
    """Calculate the relative luminance of a color according to W3C standards

    Parameters
    ----------
    color : matplotlib color or sequence of matplotlib colors
        Hex code, rgb-tuple, or html color name.

    Returns
    -------
    luminance : float(s) between 0 and 1

    """
    # Drop the alpha column; only the rgb channels contribute
    channels = mpl.colors.colorConverter.to_rgba_array(color)[:, :3]

    # Piecewise transform: linear for small values, power curve otherwise
    linear_part = channels / 12.92
    power_part = ((channels + .055) / 1.055) ** 2.4
    linearized = np.where(channels <= .03928, linear_part, power_part)

    luminance = linearized.dot([.2126, .7152, .0722])
    try:
        # A single color collapses to a plain Python float
        return luminance.item()
    except ValueError:
        return luminance


def to_utf8(obj):
    """Return a string representing a Python object.

    Strings (i.e. type ``str``) are returned unchanged.

    Byte strings (i.e. type ``bytes``) are returned as UTF-8-decoded strings.

    For other objects, the method ``__str__()`` is called, and the result is
    returned as a string.

    Parameters
    ----------
    obj : object
        Any Python object

    Returns
    -------
    s : str
        UTF-8-decoded string representation of ``obj``

    """
    if not isinstance(obj, str):
        try:
            obj = obj.decode(encoding="utf-8")
        except AttributeError:  # no usable decode(), so not bytes-like
            obj = str(obj)
    return obj


def _check_argument(param, options, value, prefix=False):
    """Raise if value for param is not in options."""
    if prefix and value is not None:
        failure = not any(value.startswith(p) for p in options if isinstance(p, str))
    else:
        failure = value not in options
    if failure:
        raise ValueError(
            f"The value for `{param}` must be one of {options}, "
            f"but {repr(value)} was passed."
        )
    return value


def _assign_default_kwargs(kws, call_func, source_func):
    """Assign default kwargs for call_func using values from source_func."""
    # This exists so that axes-level functions and figure-level functions can
    # both call a Plotter method while having the default kwargs be defined in
    # the signature of the axes-level function.
    # An alternative would be to have a decorator on the method that sets its
    # defaults based on those defined in the axes-level function.
    # Then the figure-level function would not need to worry about defaults.
    # I am not sure which is better.
    needed = inspect.signature(call_func).parameters
    defaults = inspect.signature(source_func).parameters

    for param in needed:
        if param in defaults and param not in kws:
            kws[param] = defaults[param].default

    return kws


def adjust_legend_subtitles(legend):
    """
    Make invisible-handle "subtitles" entries look more like titles.

    Note: This function is not part of the public API and may be changed or removed.

    """
    # Legend title not in rcParams until 3.0
    title_size = plt.rcParams.get("legend.title_fontsize", None)
    rows = legend.findobj(mpl.offsetbox.VPacker)[0].get_children()
    for row in rows:
        handle_box, label_box = row.get_children()
        if all(artist.get_visible() for artist in handle_box.get_children()):
            # An ordinary entry with fully visible handles; leave it alone
            continue
        # Collapse the handle area so the text sits where a title would
        handle_box.set_width(0)
        for text in label_box.get_children():
            if title_size is not None:
                text.set_size(title_size)


def _deprecate_ci(errorbar, ci):
    """
    Warn on usage of ci= and convert to appropriate errorbar= arg.

    ci was deprecated when errorbar was added in 0.12. It should not be removed
    completely for some time, but it can be moved out of function definitions
    (and extracted from kwargs) after one cycle.

    """
    # Nothing to translate when ci was left at its sentinel value
    if ci is deprecated or ci == "deprecated":
        return errorbar

    if ci is None:
        errorbar = None
    else:
        errorbar = "sd" if ci == "sd" else ("ci", ci)

    warnings.warn(
        "\n\nThe `ci` parameter is deprecated. "
        f"Use `errorbar={errorbar!r}` for the same effect.\n",
        FutureWarning,
        stacklevel=3,
    )
    return errorbar


def _get_transform_functions(ax, axis):
    """Return the forward and inverse transforms for a given axis."""
    axis_obj = getattr(ax, f"{axis}axis")
    transform = axis_obj.get_transform()
    return transform.transform, transform.inverted().transform


@contextmanager
def _disable_autolayout():
    """Context manager that turns off rc-controlled auto-layout while active."""
    # This is a workaround for an issue in matplotlib, for details see
    # https://github.com/mwaskom/seaborn/issues/2914
    # The only effect of this rcParam is to set the default value for
    # layout= in plt.figure, so we could just do that instead.
    # But then we would need to own the complexity of the transition
    # from tight_layout=True -> layout="tight". This seems easier,
    # but can be removed when (if) that is simpler on the matplotlib side,
    # or if the layout algorithms are improved to handle figure legends.
    rc_key = "figure.autolayout"
    saved = mpl.rcParams[rc_key]
    try:
        mpl.rcParams[rc_key] = False
        yield
    finally:
        # Restore the caller's setting even if the managed body raised
        mpl.rcParams[rc_key] = saved


def _version_predates(lib: ModuleType, version: str) -> bool:
    """Return True when the version of ``lib`` is older than ``version``."""
    installed = Version(lib.__version__)
    threshold = Version(version)
    return installed < threshold


def _scatter_legend_artist(**kws):
    """Build a data-less Line2D that imitates a scatter marker in a legend.

    Scatter-style keyword arguments are translated to their Line2D marker
    equivalents; any remaining keywords are passed to Line2D unchanged.
    """
    kws = normalize_kwargs(kws, mpl.collections.PathCollection)

    edgecolor = kws.pop("edgecolor", None)
    marker = kws.pop("marker", "o")
    # Scatter sizes are areas while Line2D marker sizes are lengths
    default_area = mpl.rcParams["lines.markersize"] ** 2
    markersize = np.sqrt(kws.pop("s", default_area))
    facecolor = kws.pop("facecolor", kws.get("color"))
    edgewidth = kws.pop("linewidth", 0)

    line_kws = dict(
        linestyle="",
        marker=marker,
        markersize=markersize,
        markerfacecolor=facecolor,
        markeredgewidth=edgewidth,
    )
    # Leftover keywords go last so they override the translated ones
    line_kws.update(kws)

    if edgecolor is not None:
        line_kws["markeredgecolor"] = (
            line_kws["markerfacecolor"] if edgecolor == "face" else edgecolor
        )

    return mpl.lines.Line2D([], [], **line_kws)


def _get_patch_legend_artist(fill):
    """Return a factory for zero-size Rectangle legend handles.

    When ``fill`` is true a ``color`` keyword becomes the face color;
    otherwise it becomes the edge color and the face is left unfilled.
    """

    def legend_artist(**kws):

        color = kws.pop("color", None)
        if color is not None:
            if fill:
                overrides = {"facecolor": color}
            else:
                overrides = {"edgecolor": color, "facecolor": "none"}
            kws.update(overrides)

        return mpl.patches.Rectangle((0, 0), 0, 0, **kws)

    return legend_artist


================================================
FILE: seaborn/widgets.py
================================================
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import LinearSegmentedColormap

try:
    from ipywidgets import interact, FloatSlider, IntSlider
except ImportError:
    def interact(f):
        msg = "Interactive palettes require `ipywidgets`, which is not installed."
        raise ImportError(msg)

from .miscplot import palplot
from .palettes import (color_palette, dark_palette, light_palette,
                       diverging_palette, cubehelix_palette)


__all__ = ["choose_colorbrewer_palette", "choose_cubehelix_palette",
           "choose_dark_palette", "choose_light_palette",
           "choose_diverging_palette"]


def _init_mutable_colormap():
    """Create a matplotlib colormap that will be updated by the widgets."""
    grey_ramp = color_palette("Greys", 256)
    mutable = LinearSegmentedColormap.from_list("interactive", grey_ramp)
    # Initialize the colormap eagerly so its lookup table can be edited later
    mutable._init()
    mutable._set_extremes()
    return mutable


def _update_lut(cmap, colors):
    """Change the LUT values in a matplotlib colormap in-place."""
    cmap._lut[:256] = colors
    cmap._set_extremes()


def _show_cmap(cmap):
    """Draw a continuous matplotlib colormap as a horizontal strip."""
    from .rcmod import axes_style  # Avoid circular import
    with axes_style("white"):
        _, strip_ax = plt.subplots(figsize=(8.25, .75))
    strip_ax.set(xticks=[], yticks=[])
    # A single row of 256 evenly spaced values spanning the whole colormap
    gradient = np.linspace(0, 1, 256).reshape(1, -1)
    strip_ax.pcolormesh(gradient, cmap=cmap)


def choose_colorbrewer_palette(data_type, as_cmap=False):
    """Select a palette from the ColorBrewer set.

    These palettes are built into matplotlib and can be used by name in
    many seaborn functions, or by passing the object returned by this function.

    Parameters
    ----------
    data_type : {'sequential', 'diverging', 'qualitative'}
        This describes the kind of data you want to visualize. See the seaborn
        color palette docs for more information about how to choose this value.
        Note that you can pass substrings (e.g. 'q' for 'qualitative').

    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions. The same object is
        updated in place each time the widget callback runs.

    Raises
    ------
    ValueError
        If a qualitative palette is requested together with ``as_cmap=True``.

    See Also
    --------
    dark_palette : Create a sequential palette with dark low values.
    light_palette : Create a sequential palette with bright low values.
    diverging_palette : Create a diverging palette from selected colors.
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.


    """
    if data_type.startswith("q") and as_cmap:
        raise ValueError("Qualitative palettes cannot be colormaps.")

    # The widget callbacks below mutate `pal` (or `cmap`) in place, so the
    # object returned at the end reflects the latest widget state
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    # Only the first letter of `data_type` is inspected; a value matching none
    # of the branches defines no widget and the initial object is returned
    if data_type.startswith("s"):
        opts = ["Greys", "Reds", "Greens", "Blues", "Oranges", "Purples",
                "BuGn", "BuPu", "GnBu", "OrRd", "PuBu", "PuRd", "RdPu", "YlGn",
                "PuBuGn", "YlGnBu", "YlOrBr", "YlOrRd"]
        variants = ["regular", "reverse", "dark"]

        # The parameter defaults of the decorated function define the widgets
        @interact
        def choose_sequential(name=opts, n=(2, 18),
                              desat=FloatSlider(min=0, max=1, value=1),
                              variant=variants):
            # Variants are selected through a suffix on the palette name
            if variant == "reverse":
                name += "_r"
            elif variant == "dark":
                name += "_d"

            if as_cmap:
                colors = color_palette(name, 256, desat)
                # Append an opaque alpha column before writing into the LUT
                _update_lut(cmap, np.c_[colors, np.ones(256)])
                _show_cmap(cmap)
            else:
                pal[:] = color_palette(name, n, desat)
                palplot(pal)

    elif data_type.startswith("d"):
        opts = ["RdBu", "RdGy", "PRGn", "PiYG", "BrBG",
                "RdYlBu", "RdYlGn", "Spectral"]
        variants = ["regular", "reverse"]

        @interact
        def choose_diverging(name=opts, n=(2, 16),
                             desat=FloatSlider(min=0, max=1, value=1),
                             variant=variants):
            if variant == "reverse":
                name += "_r"
            if as_cmap:
                colors = color_palette(name, 256, desat)
                # Append an opaque alpha column before writing into the LUT
                _update_lut(cmap, np.c_[colors, np.ones(256)])
                _show_cmap(cmap)
            else:
                pal[:] = color_palette(name, n, desat)
                palplot(pal)

    elif data_type.startswith("q"):
        opts = ["Set1", "Set2", "Set3", "Paired", "Accent",
                "Pastel1", "Pastel2", "Dark2"]

        # No colormap branch here: as_cmap=True was rejected at the top
        @interact
        def choose_qualitative(name=opts, n=(2, 16),
                               desat=FloatSlider(min=0, max=1, value=1)):
            pal[:] = color_palette(name, n, desat)
            palplot(pal)

    if as_cmap:
        return cmap
    return pal


def choose_dark_palette(input="husl", as_cmap=False):
    """Launch an interactive widget to create a dark sequential palette.

    This corresponds with the :func:`dark_palette` function. This kind
    of palette is good for data that range between relatively uninteresting
    low values and interesting high values.

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    input : {'husl', 'hls', 'rgb'}
        Color space for defining the seed value. Note that the default is
        different than the default input for :func:`dark_palette`.
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions. The same object is
        updated in place each time the widget callback runs.

    See Also
    --------
    dark_palette : Create a sequential palette with dark low values.
    light_palette : Create a sequential palette with bright low values.
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    # The widget callbacks below mutate `pal` (or `cmap`) in place, so the
    # object returned at the end reflects the latest widget state
    pal = []
    if as_cmap:
        cmap = _init_mutable_colormap()

    # One widget per color space; the parameter defaults of each decorated
    # function define the sliders. An `input` matching none of the branches
    # defines no widget and the initial object is returned.
    if input == "rgb":
        @interact
        def choose_dark_palette_rgb(r=(0., 1.),
                                    g=(0., 1.),
                                    b=(0., 1.),
                                    n=(3, 17)):
            color = r, g, b
            if as_cmap:
                colors = dark_palette(color, 256, input="rgb")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = dark_palette(color, n, input="rgb")
                palplot(pal)

    elif input == "hls":
        @interact
        def choose_dark_palette_hls(h=(0., 1.),
                                    l=(0., 1.),  # noqa: E741
                                    s=(0., 1.),
                                    n=(3, 17)):
            color = h, l, s
            if as_cmap:
                colors = dark_palette(color, 256, input="hls")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = dark_palette(color, n, input="hls")
                palplot(pal)

    elif input == "husl":
        # Note the integer slider ranges and the (h, s, l) channel order here,
        # unlike the unit-interval (h, l, s) sliders of the hls branch
        @interact
        def choose_dark_palette_husl(h=(0, 359),
                                     s=(0, 99),
                                     l=(0, 99),  # noqa: E741
                                     n=(3, 17)):
            color = h, s, l
            if as_cmap:
                colors = dark_palette(color, 256, input="husl")
                _update_lut(cmap, colors)
                _show_cmap(cmap)
            else:
                pal[:] = dark_palette(color, n, input="husl")
                palplot(pal)

    if as_cmap:
        return cmap
    return pal


def choose_light_palette(input="husl", as_cmap=False):
    """Launch an interactive widget to create a light sequential palette.

    This is the interactive counterpart of :func:`light_palette`. Light
    sequential palettes suit data that range from relatively uninteresting
    low values to interesting high values.

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    input : {'husl', 'hls', 'rgb'}
        Color space for defining the seed value. Note that the default is
        different than the default input for :func:`light_palette`.
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions. The widget callback
        updates this same object in place.

    See Also
    --------
    light_palette : Create a sequential palette with bright low values.
    dark_palette : Create a sequential palette with dark low values.
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    pal = []
    cmap = _init_mutable_colormap() if as_cmap else None

    def _redraw(seed, n_colors, space):
        # Shared callback body: refresh the returned object in place, then
        # display it.
        if as_cmap:
            # The colormap is always built from 256 colors; `n_colors` is
            # only used for the discrete palette.
            _update_lut(cmap, light_palette(seed, 256, input=space))
            _show_cmap(cmap)
            return
        pal[:] = light_palette(seed, n_colors, input=space)
        palplot(pal)

    # Each branch declares its own callback because the widget controls are
    # derived from the callback's parameters. An unrecognized `input` value
    # falls through without creating any widget.
    if input == "rgb":
        @interact
        def choose_light_palette_rgb(r=(0., 1.),
                                     g=(0., 1.),
                                     b=(0., 1.),
                                     n=(3, 17)):
            _redraw((r, g, b), n, "rgb")

    elif input == "hls":
        @interact
        def choose_light_palette_hls(h=(0., 1.),
                                     l=(0., 1.),  # noqa: E741
                                     s=(0., 1.),
                                     n=(3, 17)):
            _redraw((h, l, s), n, "hls")

    elif input == "husl":
        @interact
        def choose_light_palette_husl(h=(0, 359),
                                      s=(0, 99),
                                      l=(0, 99),  # noqa: E741
                                      n=(3, 17)):
            _redraw((h, s, l), n, "husl")

    return cmap if as_cmap else pal


def choose_diverging_palette(as_cmap=False):
    """Launch an interactive widget to choose a diverging color palette.

    This is the interactive counterpart of :func:`diverging_palette`.
    Diverging palettes suit data where both low and high values are
    interesting and there is a meaningful midpoint (for example, change
    scores relative to some baseline value).

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions. The widget callback
        updates this same object in place.

    See Also
    --------
    diverging_palette : Create a diverging color palette or colormap.
    choose_colorbrewer_palette : Interactively choose palettes from the
                                 colorbrewer set, including diverging palettes.

    """
    pal = []
    cmap = _init_mutable_colormap() if as_cmap else None

    @interact
    def choose_diverging_palette(
        h_neg=IntSlider(min=0, max=359, value=220),
        h_pos=IntSlider(min=0, max=359, value=10),
        s=IntSlider(min=0, max=99, value=74),
        l=IntSlider(min=0, max=99, value=50),  # noqa: E741
        sep=IntSlider(min=1, max=50, value=10),
        n=(2, 16),
        center=["light", "dark"]
    ):
        if not as_cmap:
            # Discrete mode: overwrite the returned list in place and show it
            pal[:] = diverging_palette(h_neg, h_pos, s, l, sep, n, center)
            palplot(pal)
            return

        # Colormap mode always uses 256 colors; the `n` control is ignored
        lut_colors = diverging_palette(h_neg, h_pos, s, l, sep, 256, center)
        _update_lut(cmap, lut_colors)
        _show_cmap(cmap)

    return cmap if as_cmap else pal


def choose_cubehelix_palette(as_cmap=False):
    """Launch an interactive widget to create a sequential cubehelix palette.

    This is the interactive counterpart of :func:`cubehelix_palette`.
    Sequential palettes suit data that range from relatively uninteresting
    low values to interesting high values. The cubehelix system allows the
    palette to have more hue variance across the range, which can be helpful
    for distinguishing a wider range of values.

    Requires IPython 2+ and must be used in the notebook.

    Parameters
    ----------
    as_cmap : bool
        If True, the return value is a matplotlib colormap rather than a
        list of discrete colors.

    Returns
    -------
    pal or cmap : list of colors or matplotlib colormap
        Object that can be passed to plotting functions. The widget callback
        updates this same object in place.

    See Also
    --------
    cubehelix_palette : Create a sequential palette or colormap using the
                        cubehelix system.

    """
    pal = []
    cmap = _init_mutable_colormap() if as_cmap else None

    @interact
    def choose_cubehelix(n_colors=IntSlider(min=2, max=16, value=9),
                         start=FloatSlider(min=0, max=3, value=0),
                         rot=FloatSlider(min=-1, max=1, value=.4),
                         gamma=FloatSlider(min=0, max=5, value=1),
                         hue=FloatSlider(min=0, max=1, value=.8),
                         light=FloatSlider(min=0, max=1, value=.85),
                         dark=FloatSlider(min=0, max=1, value=.15),
                         reverse=False):

        # Positional arguments shared by both output modes
        shape_args = (start, rot, gamma, hue, light, dark, reverse)

        if not as_cmap:
            # Discrete mode: overwrite the returned list in place and show it
            pal[:] = cubehelix_palette(n_colors, *shape_args)
            palplot(pal)
            return

        # Colormap mode always uses 256 colors; `n_colors` is ignored
        lut_colors = cubehelix_palette(256, *shape_args)
        # Append a column of ones to the colors (presumably the alpha channel
        # expected by the lookup table -- confirm against _update_lut)
        _update_lut(cmap, np.c_[lut_colors, np.ones(256)])
        _show_cmap(cmap)

    return cmap if as_cmap else pal


================================================
FILE: setup.cfg
================================================
[flake8]
max-line-length = 88
exclude = seaborn/cm.py,seaborn/external
ignore = E741,F522,W503

[mypy]
# Currently this ignores pandas and matplotlib.
# We may want to write custom stub files for the parts we use:
# the available third-party stubs are less complete than they
# would need to be to be useful.
ignore_missing_imports = True

[coverage:run]
omit =
    seaborn/widgets.py
    seaborn/external/*
    seaborn/colors/*
    seaborn/cm.py
    seaborn/conftest.py

[coverage:report]
exclude_lines =
    pragma: no cover
    if TYPE_CHECKING:
    raise NotImplementedError


================================================
FILE: tests/__init__.py
================================================


================================================
FILE: tests/_core/__init__.py
================================================


================================================
FILE: tests/_core/test_data.py
================================================
import functools
import numpy as np
import pandas as pd

import pytest
from numpy.testing import assert_array_equal
from pandas.testing import assert_series_equal

from seaborn._core.data import PlotData


# Series comparison with name checking disabled; the tests below compare
# PlotData frame columns against source columns that carry different names.
assert_vector_equal = functools.partial(assert_series_equal, check_names=False)


class TestPlotData:
    """Tests for constructing PlotData and for combining it via ``join``."""

    @pytest.fixture
    def long_variables(self):
        """Mapping from plot variable names to column names of ``long_df``."""
        variables = dict(x="x", y="y", color="a", size="z", style="s_cat")
        return variables

    def test_named_vectors(self, long_df, long_variables):
        """Variables given as column names are looked up in the data."""
        p = PlotData(long_df, long_variables)
        assert p.source_data is long_df
        assert p.source_vars is long_variables
        for key, val in long_variables.items():
            assert p.names[key] == val
            assert_vector_equal(p.frame[key], long_df[val])

    def test_named_and_given_vectors(self, long_df, long_variables):
        """Column names, Series, and arrays can be mixed in one specification."""
        long_variables["y"] = long_df["b"]
        long_variables["size"] = long_df["z"].to_numpy()

        p = PlotData(long_df, long_variables)

        assert_vector_equal(p.frame["color"], long_df[long_variables["color"]])
        assert_vector_equal(p.frame["y"], long_df["b"])
        assert_vector_equal(p.frame["size"], long_df["z"])

        # A Series contributes its name; a bare array has no name
        assert p.names["color"] == long_variables["color"]
        assert p.names["y"] == "b"
        assert p.names["size"] is None

        # Unnamed vectors are identified by object id instead
        assert p.ids["color"] == long_variables["color"]
        assert p.ids["y"] == "b"
        assert p.ids["size"] == id(long_variables["size"])

    def test_index_as_variable(self, long_df, long_variables):
        """A named index can be referenced like a column."""
        index = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int)
        long_variables["x"] = "i"
        p = PlotData(long_df.set_index(index), long_variables)

        assert p.names["x"] == p.ids["x"] == "i"
        assert_vector_equal(p.frame["x"], pd.Series(index, index))

    def test_multiindex_as_variables(self, long_df, long_variables):
        """Each named level of a MultiIndex can be referenced like a column."""
        index_i = pd.Index(np.arange(len(long_df)) * 2 + 10, name="i", dtype=int)
        index_j = pd.Index(np.arange(len(long_df)) * 3 + 5, name="j", dtype=int)
        index = pd.MultiIndex.from_arrays([index_i, index_j])
        long_variables.update({"x": "i", "y": "j"})

        p = PlotData(long_df.set_index(index), long_variables)
        assert_vector_equal(p.frame["x"], pd.Series(index_i, index))
        assert_vector_equal(p.frame["y"], pd.Series(index_j, index))

    def test_int_as_variable_key(self, rng):
        """An integer that is a column label is used as a key into the data."""
        df = pd.DataFrame(rng.uniform(size=(10, 3)))

        var = "x"
        key = 2

        p = PlotData(df, {var: key})
        assert_vector_equal(p.frame[var], df[key])
        assert p.names[var] == p.ids[var] == str(key)

    def test_int_as_variable_value(self, long_df):
        """An integer that is not a column label is used as a constant value."""
        p = PlotData(long_df, {"x": 0, "y": "y"})
        assert (p.frame["x"] == 0).all()
        assert p.names["x"] is None
        assert p.ids["x"] == id(0)

    def test_tuple_as_variable_key(self, rng):
        """A tuple can address a column under MultiIndex columns."""
        cols = pd.MultiIndex.from_product([("a", "b", "c"), ("x", "y")])
        df = pd.DataFrame(rng.uniform(size=(10, 6)), columns=cols)

        var = "color"
        key = ("b", "y")
        p = PlotData(df, {var: key})
        assert_vector_equal(p.frame[var], df[key])
        assert p.names[var] == p.ids[var] == str(key)

    def test_dict_as_data(self, long_dict, long_variables):
        """A dict of vectors is accepted as the data source."""
        p = PlotData(long_dict, long_variables)
        assert p.source_data is long_dict
        for key, val in long_variables.items():
            assert_vector_equal(p.frame[key], pd.Series(long_dict[val]))

    @pytest.mark.parametrize(
        "vector_type",
        ["series", "numpy", "list"],
    )
    def test_vectors_various_types(self, long_df, long_variables, vector_type):
        """Vectors may be passed directly, without data, in several types."""
        variables = {key: long_df[val] for key, val in long_variables.items()}
        if vector_type == "numpy":
            variables = {key: val.to_numpy() for key, val in variables.items()}
        elif vector_type == "list":
            variables = {key: val.to_list() for key, val in variables.items()}

        p = PlotData(None, variables)

        assert list(p.names) == list(long_variables)
        if vector_type == "series":
            assert p.source_vars is variables
            assert p.names == p.ids == {key: val.name for key, val in variables.items()}
        else:
            # Arrays and lists carry no name, so ids fall back to object ids
            assert p.names == {key: None for key in variables}
            assert p.ids == {key: id(val) for key, val in variables.items()}

        for key, val in long_variables.items():
            if vector_type == "series":
                assert_vector_equal(p.frame[key], long_df[val])
            else:
                assert_array_equal(p.frame[key], long_df[val])

    def test_none_as_variable_value(self, long_df):
        """Variables assigned None are left out of the frame and the names."""
        p = PlotData(long_df, {"x": "z", "y": None})
        assert list(p.frame.columns) == ["x"]
        assert p.names == p.ids == {"x": "z"}

    def test_frame_and_vector_mismatched_lengths(self, long_df):
        """An array longer than the data raises ValueError."""
        vector = np.arange(len(long_df) * 2)
        with pytest.raises(ValueError):
            PlotData(long_df, {"x": "x", "y": vector})

    @pytest.mark.parametrize(
        "arg", [{}, pd.DataFrame()],
    )
    def test_empty_data_input(self, arg):
        """Empty data sources (and empty vectors) produce an empty frame."""
        p = PlotData(arg, {})
        assert p.frame.empty
        assert not p.names

        # The empty dict is additionally exercised as a vector value
        if not isinstance(arg, pd.DataFrame):
            p = PlotData(None, dict(x=arg, y=arg))
            assert p.frame.empty
            assert not p.names

    def test_index_alignment_series_to_dataframe(self):
        """A Series is aligned with the data on index, filling gaps with NaN."""
        x = [1, 2, 3]
        x_index = pd.Index(x, dtype=int)

        y_values = [3, 4, 5]
        y_index = pd.Index(y_values, dtype=int)
        y = pd.Series(y_values, y_index, name="y")

        data = pd.DataFrame(dict(x=x), index=x_index)

        p = PlotData(data, {"x": "x", "y": y})

        x_col_expected = pd.Series([1, 2, 3, np.nan, np.nan], np.arange(1, 6))
        y_col_expected = pd.Series([np.nan, np.nan, 3, 4, 5], np.arange(1, 6))
        assert_vector_equal(p.frame["x"], x_col_expected)
        assert_vector_equal(p.frame["y"], y_col_expected)

    def test_index_alignment_between_series(self):
        """Two Series are aligned with each other on the union of indexes."""
        x_index = [1, 2, 3]
        x_values = [10, 20, 30]
        x = pd.Series(x_values, x_index, name="x")

        y_index = [3, 4, 5]
        y_values = [300, 400, 500]
        y = pd.Series(y_values, y_index, name="y")

        p = PlotData(None, {"x": x, "y": y})

        idx_expected = [1, 2, 3, 4, 5]
        x_col_expected = pd.Series([10, 20, 30, np.nan, np.nan], idx_expected)
        y_col_expected = pd.Series([np.nan, np.nan, 300, 400, 500], idx_expected)
        assert_vector_equal(p.frame["x"], x_col_expected)
        assert_vector_equal(p.frame["y"], y_col_expected)

    def test_key_not_in_data_raises(self, long_df):
        """A string key missing from the data raises a descriptive ValueError."""
        var = "x"
        key = "what"
        msg = f"Could not interpret value `{key}` for `{var}`. An entry with this name"
        with pytest.raises(ValueError, match=msg):
            PlotData(long_df, {var: key})

    def test_key_with_no_data_raises(self):
        """A string key without any data raises a descriptive ValueError."""
        var = "x"
        key = "what"
        msg = f"Could not interpret value `{key}` for `{var}`. Value is a string,"
        with pytest.raises(ValueError, match=msg):
            PlotData(None, {var: key})

    def test_data_vector_different_lengths_raises(self, long_df):
        """An array shorter than the data raises a descriptive ValueError."""
        vector = np.arange(len(long_df) - 5)
        msg = "Length of ndarray vectors must match length of `data`"
        with pytest.raises(ValueError, match=msg):
            PlotData(long_df, {"y": vector})

    def test_undefined_variables_raise(self, long_df):
        """An unknown column name raises regardless of which variable has it."""
        with pytest.raises(ValueError):
            PlotData(long_df, dict(x="not_in_df"))

        with pytest.raises(ValueError):
            PlotData(long_df, dict(x="x", y="not_in_df"))

        with pytest.raises(ValueError):
            PlotData(long_df, dict(x="x", y="y", color="not_in_df"))

    def test_contains_operation(self, long_df):
        """``in`` tests for defined plot variables, not source column names."""
        p = PlotData(long_df, {"x": "y", "color": long_df["a"]})
        assert "x" in p
        assert "y" not in p
        assert "color" in p

    def test_join_add_variable(self, long_df):
        """Joining without new data can add a variable from the original data."""
        v1 = {"x": "x", "y": "f"}
        v2 = {"color": "a"}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(None, v2)

        for var, key in dict(**v1, **v2).items():
            assert var in p2
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_replace_variable(self, long_df):
        """Joining without new data can redefine an existing variable."""
        v1 = {"x": "x", "y": "y"}
        v2 = {"y": "s"}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(None, v2)

        variables = v1.copy()
        variables.update(v2)

        for var, key in variables.items():
            assert var in p2
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_remove_variable(self, long_df):
        """Joining with a None value drops the variable from the result only."""
        variables = {"x": "x", "y": "f"}
        drop_var = "y"

        p1 = PlotData(long_df, variables)
        p2 = p1.join(None, {drop_var: None})

        assert drop_var in p1
        assert drop_var not in p2
        assert drop_var not in p2.frame
        assert drop_var not in p2.names

    def test_join_all_operations(self, long_df):
        """Replace, add, and remove can be combined in one join."""
        v1 = {"x": "x", "y": "y", "color": "a"}
        v2 = {"y": "s", "size": "s", "color": None}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(None, v2)

        for var, key in v2.items():
            if key is None:
                assert var not in p2
            else:
                assert p2.names[var] == key
                assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_all_operations_same_data(self, long_df):
        """Combined join operations also work when the same data is re-passed."""
        v1 = {"x": "x", "y": "y", "color": "a"}
        v2 = {"y": "s", "size": "s", "color": None}

        p1 = PlotData(long_df, v1)
        p2 = p1.join(long_df, v2)

        for var, key in v2.items():
            if key is None:
                assert var not in p2
            else:
                assert p2.names[var] == key
                assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_add_variable_new_data(self, long_df):
        """A join can add a variable drawn from a different data source."""
        d1 = long_df[["x", "y"]]
        d2 = long_df[["a", "s"]]

        v1 = {"x": "x", "y": "y"}
        v2 = {"color": "a"}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        for var, key in dict(**v1, **v2).items():
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_replace_variable_new_data(self, long_df):
        """A join can redefine a variable from a different data source."""
        d1 = long_df[["x", "y"]]
        d2 = long_df[["a", "s"]]

        v1 = {"x": "x", "y": "y"}
        v2 = {"x": "a"}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        variables = v1.copy()
        variables.update(v2)

        for var, key in variables.items():
            assert p2.names[var] == key
            assert_vector_equal(p2.frame[var], long_df[key])

    def test_join_add_variable_different_index(self, long_df):
        """Adding from partially overlapping data leaves NaN where rows lack it."""
        d1 = long_df.iloc[:70]
        d2 = long_df.iloc[30:]

        v1 = {"x": "a"}
        v2 = {"y": "z"}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        (var1, key1), = v1.items()
        (var2, key2), = v2.items()

        assert_vector_equal(p2.frame.loc[d1.index, var1], d1[key1])
        assert_vector_equal(p2.frame.loc[d2.index, var2], d2[key2])

        # Rows present in only one source are missing the other's variable
        assert p2.frame.loc[d2.index.difference(d1.index), var1].isna().all()
        assert p2.frame.loc[d1.index.difference(d2.index), var2].isna().all()

    def test_join_replace_variable_different_index(self, long_df):
        """Replacing from partially overlapping data blanks the unmatched rows."""
        d1 = long_df.iloc[:70]
        d2 = long_df.iloc[30:]

        var = "x"
        k1, k2 = "a", "z"
        v1 = {var: k1}
        v2 = {var: k2}

        p1 = PlotData(d1, v1)
        p2 = p1.join(d2, v2)

        assert_vector_equal(p2.frame.loc[d2.index, var], d2[k2])
        assert p2.frame.loc[d1.index.difference(d2.index), var].isna().all()

    def test_join_subset_data_inherit_variables(self, long_df):
        """Joining new data without variables re-resolves the inherited ones."""
        sub_df = long_df[long_df["a"] == "b"]

        var = "y"
        p1 = PlotData(long_df, {var: var})
        p2 = p1.join(sub_df, None)

        assert_vector_equal(p2.frame.loc[sub_df.index, var], sub_df[var])
        assert p2.frame.loc[long_df.index.difference(sub_df.index), var].isna().all()

    def test_join_multiple_inherits_from_orig(self, rng):
        """A later data-less join resolves names against the original data."""
        d1 = pd.DataFrame(dict(a=rng.normal(0, 1, 100), b=rng.normal(0, 1, 100)))
        d2 = pd.DataFrame(dict(a=rng.normal(0, 1, 100)))

        p = PlotData(d1, {"x": "a"}).join(d2, {"y": "a"}).join(None, {"y": "a"})
        assert_vector_equal(p.frame["x"], d1["a"])
        assert_vector_equal(p.frame["y"], d1["a"])

    def test_bad_type(self, flat_list):
        """A data source that is neither DataFrame nor Mapping raises TypeError."""
        err = "Data source must be a DataFrame or Mapping"
        with pytest.raises(TypeError, match=err):
            PlotData(flat_list, {})

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_data_interchange(self, mock_long_df, long_df):
        """A non-pandas interchange object is accepted as the data source."""
        variables = {"x": "x", "y": "z", "color": "a"}
        p = PlotData(mock_long_df, variables)
        for var, col in variables.items():
            assert_vector_equal(p.frame[var], long_df[col])

        p = PlotData(mock_long_df, {**variables, "color": long_df["a"]})
        for var, col in variables.items():
            assert_vector_equal(p.frame[var], long_df[col])

    # Without `pd.api.interchange`, PlotData raises TypeError before the broken
    # object is ever used (see test_data_interchange_support_test), so the
    # RuntimeError expectation only holds when interchange is supported.
    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_data_interchange_failure(self, mock_long_df):
        """A broken interchange object makes PlotData raise RuntimeError."""
        mock_long_df._data = None  # Break to_pandas()
        with pytest.raises(RuntimeError, match="Encountered an exception"):
            PlotData(mock_long_df, {"x": "x"})

    @pytest.mark.skipif(
        condition=hasattr(pd.api, "interchange"),
        reason="Tests graceful failure without support for dataframe interchange"
    )
    def test_data_interchange_support_test(self, mock_long_df):
        """Without interchange support, non-pandas data raises TypeError."""
        with pytest.raises(TypeError, match="Support for non-pandas DataFrame"):
            PlotData(mock_long_df, {"x": "x"})


================================================
FILE: tests/_core/test_groupby.py
================================================

import numpy as np
import pandas as pd

import pytest
from numpy.testing import assert_array_equal

from seaborn._core.groupby import GroupBy


@pytest.fixture
def df():
    """Small frame with two string columns (a, b) and two numeric ones (x, y)."""
    rows = [
        ["a", "g", 1, .2],
        ["b", "h", 3, .5],
        ["a", "f", 2, .8],
        ["a", "h", 1, .3],
        ["b", "f", 2, .4],
    ]
    return pd.DataFrame(data=rows, columns=["a", "b", "x", "y"])


def test_init_from_list():
    """A list of names yields an order mapping with None for every name."""
    expected = {"a": None, "c": None, "b": None}
    grouper = GroupBy(["a", "c", "b"])
    assert grouper.order == expected


def test_init_from_dict():
    """A dict passed to GroupBy is exposed unchanged as ``order``."""
    spec = {"a": [3, 2, 1], "c": None, "b": ["x", "y", "z"]}
    grouper = GroupBy(spec)
    assert grouper.order == spec


def test_init_requires_order():
    """Constructing a GroupBy from an empty list raises ValueError."""
    expected_msg = "GroupBy requires at least one"
    with pytest.raises(ValueError, match=expected_msg):
        GroupBy([])


def test_at_least_one_grouping_variable_required(df):
    """Aggregating raises when no grouping variable is a column of the data."""
    grouper = GroupBy(["z"])  # "z" is not a column of the fixture frame
    expected_msg = "No grouping variables are present"
    with pytest.raises(ValueError, match=expected_msg):
        grouper.agg(df, x="mean")


def test_agg_one_grouper(df):
    """Aggregating over one grouper gives one row per level of that grouper."""
    result = GroupBy(["a"]).agg(df, {"y": "max"})

    assert_array_equal(result.index, [0, 1])
    assert_array_equal(result.columns, ["a", "y"])

    expected_columns = {"a": ["a", "b"], "y": [.8, .5]}
    for name, values in expected_columns.items():
        assert_array_equal(result[name], values)


def test_agg_two_groupers(df):
    """Two groupers give every level combination, with NaN for empty cells."""
    nan = np.nan
    result = GroupBy(["a", "x"]).agg(df, {"y": "min"})

    assert_array_equal(result.index, list(range(6)))
    assert_array_equal(result.columns, ["a", "x", "y"])

    expected_columns = {
        "a": ["a", "a", "a", "b", "b", "b"],
        "x": [1, 2, 3, 1, 2, 3],
        "y": [.2, .8, nan, nan, .4, .5],
    }
    for name, values in expected_columns.items():
        assert_array_equal(result[name], values)


def test_agg_two_groupers_ordered(df):
    """Explicit level orders determine the row order of the aggregation."""
    level_order = {"b": ["h", "g", "f"], "x": [3, 2, 1]}

    def first_value(values):
        return values.iloc[0]

    result = GroupBy(level_order).agg(df, {"a": "min", "y": first_value})

    assert_array_equal(result.index, list(range(9)))
    assert_array_equal(result.columns, ["a", "b", "x", "y"])
    assert_array_equal(
        result["b"], ["h", "h", "h", "g", "g", "g", "f", "f", "f"]
    )
    assert_array_equal(result["x"], [3, 2, 1, 3, 2, 1, 3, 2, 1])

    # Only four of the nine (b, x) combinations occur in the fixture data
    missing = [False, True, False, True, True, False, True, False, True]
    assert_array_equal(result["a"].isna(), missing)
    assert_array_equal(result["a"].dropna(), ["b", "a", "a", "a"])
    assert_array_equal(result["y"].dropna(), [.5, .3, .2, .8])


def test_apply_no_grouper(df):
    """With no grouping column in the data, the function sees the whole frame."""
    xy = df[["x", "y"]]  # drops the "a" column that the GroupBy asks for

    def sort_by_x(frame):
        return frame.sort_values("x")

    result = GroupBy(["a"]).apply(xy, sort_by_x)

    assert_array_equal(result.columns, ["x", "y"])
    assert_array_equal(result["x"], xy["x"].sort_values())
    assert_array_equal(result["y"], xy.loc[np.argsort(xy["x"]), "y"])


def test_apply_one_grouper(df):
    """The function is applied per group; results are stacked with a new index."""
    def sort_by_x(frame):
        return frame.sort_values("x")

    result = GroupBy(["a"]).apply(df, sort_by_x)

    assert_array_equal(result.index, list(range(5)))
    assert_array_equal(result.columns, ["a", "b", "x", "y"])

    expected_columns = {
        "a": ["a", "a", "a", "b", "b"],
        "b": ["g", "h", "f", "f", "h"],
        "x": [1, 1, 2, 2, 3],
    }
    for name, values in expected_columns.items():
        assert_array_equal(result[name], values)


def test_apply_mutate_columns(df):
    """The applied function may return frames of a different length."""
    grid = np.arange(0, 5)
    predictions = []  # per-group fitted values, in the order groups are seen

    def linear_fit(group):
        coefs = np.polyfit(group["x"], group["y"], 1)
        fitted = np.polyval(coefs, grid)
        predictions.append(fitted)
        return pd.DataFrame(dict(x=grid, y=fitted))

    result = GroupBy(["a"]).apply(df, linear_fit)

    n_grid = grid.size
    assert_array_equal(result.index, np.arange(n_grid * 2))
    assert_array_equal(result.columns, ["a", "x", "y"])
    assert_array_equal(result["a"], ["a"] * n_grid + ["b"] * n_grid)
    assert_array_equal(result["x"], grid.tolist() + grid.tolist())
    assert_array_equal(result["y"], np.concatenate(predictions))


def test_apply_replace_columns(df):
    """The applied function may return different columns than it received."""

    def add_sorted_cumsum(group):
        # Cumulative sum of y taken in order of increasing x
        x_sorted = group["x"].sort_values()
        y_running = group.loc[x_sorted.index, "y"].cumsum()
        return pd.DataFrame(dict(x=x_sorted.values, z=y_running.values))

    result = GroupBy(["a"]).apply(df, add_sorted_cumsum)

    assert_array_equal(result.index, df.index)
    assert_array_equal(result.columns, ["a", "x", "z"])

    expected_columns = {
        "a": ["a", "a", "a", "b", "b"],
        "x": [1, 1, 2, 2, 3],
        "z": [.2, .5, 1.3, .4, .9],
    }
    for name, values in expected_columns.items():
        assert_array_equal(result[name], values)


================================================
FILE: tests/_core/test_moves.py
================================================

from itertools import product

import numpy as np
import pandas as pd
from pandas.testing import assert_series_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._core.moves import Dodge, Jitter, Shift, Stack, Norm
from seaborn._core.rules import categorical_order
from seaborn._core.groupby import GroupBy

import pytest


class MoveFixtures:
    """Shared data fixtures for the move test classes."""

    @pytest.fixture
    def df(self, rng):
        """Random 50-row frame with x/y, two grouping columns, width, baseline."""
        n_rows = 50
        # Keep the draws in this order: if `rng` is stateful (presumably a
        # seeded NumPy generator), reordering would change the values.
        x = rng.choice([0., 1., 2., 3.], n_rows)
        y = rng.normal(0, 1, n_rows)
        grp2 = rng.choice(["a", "b"], n_rows)
        grp3 = rng.choice(["x", "y", "z"], n_rows)
        return pd.DataFrame(
            dict(x=x, y=y, grp2=grp2, grp3=grp3, width=0.8, baseline=0)
        )

    @pytest.fixture
    def toy_df(self):
        """Three-row frame where group "a" is absent at x == 1."""
        return pd.DataFrame(dict(
            x=[0, 0, 1],
            y=[1, 2, 3],
            grp=["a", "b", "b"],
            width=.8,
            baseline=0,
        ))

    @pytest.fixture
    def toy_df_widths(self, toy_df):
        """The ``toy_df`` frame with per-row widths assigned in place."""
        toy_df["width"] = [.8, .2, .4]
        return toy_df

    @pytest.fixture
    def toy_df_facets(self):
        """Six-row frame that adds a two-level ``col`` column to the toy data."""
        return pd.DataFrame(dict(
            x=[0, 0, 1, 0, 1, 2],
            y=[1, 2, 3, 1, 2, 3],
            grp=["a", "b", "a", "b", "a", "b"],
            col=["x", "x", "x", "y", "y", "y"],
            width=.8,
            baseline=0,
        ))


class TestJitter(MoveFixtures):
    """Tests for the Jitter move."""

    def get_groupby(self, data, orient):
        """Build a GroupBy over all columns but the non-orient axis and width."""
        flipped = dict(x="y", y="x")
        excluded = (flipped[orient], "width")
        return GroupBy([name for name in data if name not in excluded])

    def check_same(self, res, df, *cols):
        """Assert that the named columns are unchanged between df and res."""
        for name in cols:
            assert_series_equal(res[name], df[name])

    def check_pos(self, res, df, var, limit):
        """Assert every value of `var` moved, but by less than half of `limit`."""
        moved, original = res[var], df[var]
        half_range = limit / 2

        assert (moved != original).all()
        assert (moved < original + half_range).all()
        assert (moved > original - half_range).all()

    def test_default(self, df):
        """With no arguments, the orient axis is jittered within 0.2 * width."""
        axis = "x"
        res = Jitter()(df, self.get_groupby(df, axis), axis, {})

        self.check_same(res, df, "y", "grp2", "width")
        self.check_pos(res, df, "x", 0.2 * df["width"])
        assert (res["x"] - df["x"]).abs().min() > 0

    def test_width(self, df):
        """The `width` parameter scales jitter relative to the data width."""
        rel_width = .4
        axis = "x"
        move = Jitter(width=rel_width)
        res = move(df, self.get_groupby(df, axis), axis, {})

        self.check_same(res, df, "y", "grp2", "width")
        self.check_pos(res, df, "x", rel_width * df["width"])

    def test_x(self, df):
        """The `x` parameter bounds jitter on x in data units."""
        amount = .2
        axis = "x"
        move = Jitter(x=amount)
        res = move(df, self.get_groupby(df, axis), axis, {})

        self.check_same(res, df, "y", "grp2", "width")
        self.check_pos(res, df, "x", amount)

    def test_y(self, df):
        """The `y` parameter bounds jitter on y and leaves x untouched."""
        amount = .2
        axis = "x"
        move = Jitter(y=amount)
        res = move(df, self.get_groupby(df, axis), axis, {})

        self.check_same(res, df, "x", "grp2", "width")
        self.check_pos(res, df, "y", amount)

    def test_seed(self, df):
        """Two moves built with the same seed produce identical jitter."""
        params = dict(width=.2, y=.1, seed=0)
        axis = "x"
        groupby = self.get_groupby(df, axis)

        first = Jitter(**params)(df, groupby, axis, {})
        second = Jitter(**params)(df, groupby, axis, {})

        for coord in ("x", "y"):
            assert_series_equal(first[coord], second[coord])


class TestDodge(MoveFixtures):
    """Tests for the `Dodge` move on toy fixtures and on the larger `df` fixture."""

    # First some very simple toy examples

    def test_default(self, toy_df):
        """Default parameters: check dodged x positions and widths on the toy data."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge()(toy_df, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1.2])
        assert_array_almost_equal(res["width"], [.4, .4, .4])

    def test_fill(self, toy_df):
        """`empty="fill"`: check positions and widths on the toy data."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge(empty="fill")(toy_df, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1])
        assert_array_almost_equal(res["width"], [.4, .4, .8])

    def test_drop(self, toy_df):
        """`"drop"` passed positionally: check positions and widths on the toy data."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge("drop")(toy_df, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1])
        assert_array_almost_equal(res["width"], [.4, .4, .4])

    def test_gap(self, toy_df):
        """`gap=.25`: positions match the default case while widths shrink to .3."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge(gap=.25)(toy_df, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1.2])
        assert_array_almost_equal(res["width"], [.3, .3, .3])

    def test_widths_default(self, toy_df_widths):
        """Default parameters on the toy data with heterogeneous widths."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge()(toy_df_widths, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.08, .32, 1.1])
        assert_array_almost_equal(res["width"], [.64, .16, .2])

    def test_widths_fill(self, toy_df_widths):
        """`empty="fill"` on the toy data with heterogeneous widths."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge(empty="fill")(toy_df_widths, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.08, .32, 1])
        assert_array_almost_equal(res["width"], [.64, .16, .4])

    def test_widths_drop(self, toy_df_widths):
        """`empty="drop"` on the toy data with heterogeneous widths."""

        groupby = GroupBy(["x", "grp"])
        res = Dodge(empty="drop")(toy_df_widths, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3])
        assert_array_almost_equal(res["x"], [-.08, .32, 1])
        assert_array_almost_equal(res["width"], [.64, .16, .2])

    def test_faceted_default(self, toy_df_facets):
        """Default parameters when "col" is part of the grouping variables."""

        groupby = GroupBy(["x", "grp", "col"])
        res = Dodge()(toy_df_facets, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, .8, .2, .8, 2.2])
        assert_array_almost_equal(res["width"], [.4] * 6)

    def test_faceted_fill(self, toy_df_facets):
        """`empty="fill"` when "col" is part of the grouping variables."""

        groupby = GroupBy(["x", "grp", "col"])
        res = Dodge(empty="fill")(toy_df_facets, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2])
        assert_array_almost_equal(res["width"], [.4, .4, .8, .8, .8, .8])

    def test_faceted_drop(self, toy_df_facets):
        """`empty="drop"` when "col" is part of the grouping variables."""

        groupby = GroupBy(["x", "grp", "col"])
        res = Dodge(empty="drop")(toy_df_facets, groupby, "x", {})

        assert_array_equal(res["y"], [1, 2, 3, 1, 2, 3])
        assert_array_almost_equal(res["x"], [-.2, .2, 1, 0, 1, 2])
        assert_array_almost_equal(res["width"], [.4] * 6)

    def test_orient(self, toy_df):
        """With x and y swapped and orient "y", the dodge is applied to y."""

        df = toy_df.assign(x=toy_df["y"], y=toy_df["x"])

        groupby = GroupBy(["y", "grp"])
        res = Dodge("drop")(df, groupby, "y", {})

        assert_array_equal(res["x"], [1, 2, 3])
        assert_array_almost_equal(res["y"], [-.2, .2, 1])
        assert_array_almost_equal(res["width"], [.4, .4, .4])

    # Now tests with slightly more complicated data

    @pytest.mark.parametrize("grp", ["grp2", "grp3"])
    def test_single_semantic(self, df, grp):
        """Dodging by one variable shifts each level by an evenly spaced offset."""

        groupby = GroupBy(["x", grp])
        res = Dodge()(df, groupby, "x", {})

        levels = categorical_order(df[grp])
        w, n = 0.8, len(levels)

        # Expected offsets: n evenly spaced slots spanning width w, centered on 0
        shifts = np.linspace(0, w - w / n, n)
        shifts -= shifts.mean()

        assert_series_equal(res["y"], df["y"])
        assert_series_equal(res["width"], df["width"] / n)

        for val, shift in zip(levels, shifts):
            rows = df[grp] == val
            assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift)

    def test_two_semantics(self, df):
        """Dodging by two variables uses one slot per combination of their levels."""

        groupby = GroupBy(["x", "grp2", "grp3"])
        res = Dodge()(df, groupby, "x", {})

        levels = categorical_order(df["grp2"]), categorical_order(df["grp3"])
        w, n = 0.8, len(levels[0]) * len(levels[1])

        # Expected offsets: n evenly spaced slots spanning width w, centered on 0
        shifts = np.linspace(0, w - w / n, n)
        shifts -= shifts.mean()

        assert_series_equal(res["y"], df["y"])
        assert_series_equal(res["width"], df["width"] / n)

        for (v2, v3), shift in zip(product(*levels), shifts):
            rows = (df["grp2"] == v2) & (df["grp3"] == v3)
            assert_series_equal(res.loc[rows, "x"], df.loc[rows, "x"] + shift)


class TestStack(MoveFixtures):
    """Tests for the `Stack` move."""

    def test_basic(self, toy_df):
        """Check x, y, and baseline after stacking the toy data."""
        grouping = GroupBy(["color", "group"])
        stacked = Stack()(toy_df, grouping, "x", {})

        expected = {"x": [0, 0, 1], "y": [1, 3, 3], "baseline": [0, 1, 0]}
        for var, values in expected.items():
            assert_array_equal(stacked[var], values)

    def test_faceted(self, toy_df_facets):
        """Check x, y, and baseline after stacking the faceted toy data."""
        grouping = GroupBy(["color", "group"])
        stacked = Stack()(toy_df_facets, grouping, "x", {})

        expected = {
            "x": [0, 0, 1, 0, 1, 2],
            "y": [1, 3, 3, 1, 2, 3],
            "baseline": [0, 1, 0, 0, 0, 0],
        }
        for var, values in expected.items():
            assert_array_equal(stacked[var], values)

    def test_misssing_data(self, toy_df):
        """A NaN in y stays NaN and the later value stacks on the earlier one."""
        frame = pd.DataFrame(dict(
            x=[0, 0, 0],
            y=[2, np.nan, 1],
            baseline=[0, 0, 0],
        ))
        stacked = Stack()(frame, None, "x", {})
        assert_array_equal(stacked["y"], [2, np.nan, 3])
        assert_array_equal(stacked["baseline"], [0, np.nan, 2])

    def test_baseline_homogeneity_check(self, toy_df):
        """Stacking raises RuntimeError when the baseline column is not constant."""
        toy_df["baseline"] = [0, 1, 2]
        stack = Stack()
        grouping = GroupBy(["color", "group"])
        pattern = "Stack move cannot be used when baselines"
        with pytest.raises(RuntimeError, match=pattern):
            stack(toy_df, grouping, "x", {})


class TestShift(MoveFixtures):
    """Tests for the `Shift` move."""

    def test_default(self, toy_df):
        """A Shift with no arguments returns every column unchanged."""
        grouping = GroupBy(["color", "group"])
        shifted = Shift()(toy_df, grouping, "x", {})
        for name in toy_df.columns:
            assert_series_equal(toy_df[name], shifted[name])

    @pytest.mark.parametrize("x,y", [(.3, 0), (0, .2), (.1, .3)])
    def test_moves(self, toy_df, x, y):
        """The x and y columns are offset by exactly the requested amounts."""
        grouping = GroupBy(["color", "group"])
        shifted = Shift(x=x, y=y)(toy_df, grouping, "x", {})
        for var, offset in (("x", x), ("y", y)):
            assert_array_equal(shifted[var], toy_df[var] + offset)


class TestNorm(MoveFixtures):
    """Tests for the `Norm` move."""

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_no_groups(self, df, orient):
        """Grouping by "null": the non-orient variable is scaled to a maximum of 1."""
        value_var = {"x": "y", "y": "x"}[orient]
        normed = Norm()(df, GroupBy(["null"]), orient, {})
        assert normed[value_var].max() == pytest.approx(1)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_default_groups(self, df, orient):
        """Grouping by "grp2": each group's maximum is 1 after normalization."""
        value_var = {"x": "y", "y": "x"}[orient]
        normed = Norm()(df, GroupBy(["grp2"]), orient, {})
        for _, subset in normed.groupby("grp2"):
            assert subset[value_var].max() == pytest.approx(1)

    def test_sum(self, df):
        """Normalizing by "sum" makes the y values add up to 1."""
        normed = Norm("sum")(df, GroupBy(["null"]), "x", {})
        assert normed["y"].sum() == pytest.approx(1)

    def test_where(self, df):
        """With `where="x == 2"`, the rows where x is 2 have a maximum y of 1."""
        normed = Norm(where="x == 2")(df, GroupBy(["null"]), "x", {})
        selected = normed.loc[normed["x"] == 2, "y"]
        assert selected.max() == pytest.approx(1)

    def test_percent(self, df):
        """With `percent=True`, the maximum y is 100."""
        normed = Norm(percent=True)(df, GroupBy(["null"]), "x", {})
        assert normed["y"].max() == pytest.approx(100)


================================================
FILE: tests/_core/test_plot.py
================================================
import io
import xml
import functools
import itertools
import warnings

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from PIL import Image

import pytest
from pandas.testing import assert_frame_equal, assert_series_equal
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._core.plot import Plot, PlotConfig, Default
from seaborn._core.scales import Continuous, Nominal, Temporal
from seaborn._core.moves import Move, Shift, Dodge
from seaborn._core.rules import categorical_order
from seaborn._core.exceptions import PlotSpecError
from seaborn._marks.base import Mark
from seaborn._stats.base import Stat
from seaborn._marks.dot import Dot
from seaborn._stats.aggregation import Agg
from seaborn.utils import _version_predates

# Series comparison used throughout these tests; names and dtypes are ignored.
# TODO do we care about int/float dtype consistency?
# Eventually most variables become floats ... but does it matter when?
# (Or rather, does it matter if it happens too early?)
assert_vector_equal = functools.partial(
    assert_series_equal, check_dtype=False, check_names=False,
)


def assert_gridspec_shape(ax, nrows=1, ncols=1):
    """Assert that the gridspec returned by `ax` has `nrows` rows and `ncols` columns."""
    gridspec = ax.get_gridspec()
    for attr, expected in (("nrows", nrows), ("ncols", ncols)):
        assert getattr(gridspec, attr) == expected


class MockMark(Mark):
    """Mark that records what it is asked to plot instead of drawing anything."""

    _grouping_props = ["color"]

    def __init__(self, *args, **kwargs):
        """Initialize the parent Mark, then reset every recording attribute."""
        super().__init__(*args, **kwargs)

        # One entry is appended to each list per split seen in `_plot`
        self.passed_keys, self.passed_data, self.passed_axes = [], [], []
        self.n_splits = 0

        # Overwritten once `_plot` has consumed all splits
        self.passed_scales = None
        self.passed_orient = None

    def _plot(self, split_gen, scales, orient):
        """Record the keys, data, and axes of each split, then scales and orient."""
        for split_keys, split_data, split_ax in split_gen():
            self.n_splits += 1
            self.passed_keys.append(split_keys)
            self.passed_data.append(split_data)
            self.passed_axes.append(split_ax)

        self.passed_scales = scales
        self.passed_orient = orient

    def _legend_artist(self, variables, value, scales):
        """Return an empty Line2D tagged with the given variables and value."""
        artist = mpl.lines.Line2D([], [])
        artist.variables = variables
        artist.value = value
        return artist


class TestInit:
    """Tests for how the `Plot` constructor handles its data and variable arguments."""

    def test_empty(self):
        """A Plot built with no arguments has no source data or variables."""
        data = Plot()._data
        assert data.source_data is None
        assert data.source_vars == {}

    def test_data_only(self, long_df):
        """A positional dataframe is stored as the source data."""
        data = Plot(long_df)._data
        assert data.source_data is long_df
        assert data.source_vars == {}

    def test_df_and_named_variables(self, long_df):
        """Variables given as column names are looked up in the dataframe."""
        variables = {"x": "a", "y": "z"}
        data = Plot(long_df, **variables)._data
        for var in variables:
            assert_vector_equal(data.frame[var], long_df[variables[var]])
        assert data.source_data is long_df
        assert data.source_vars.keys() == variables.keys()

    def test_df_and_mixed_variables(self, long_df):
        """Column names and vectors can be mixed in one call."""
        variables = {"x": "a", "y": long_df["z"]}
        data = Plot(long_df, **variables)._data
        for var, spec in variables.items():
            expected = long_df[spec] if isinstance(spec, str) else spec
            assert_vector_equal(data.frame[var], expected)
        assert data.source_data is long_df
        assert data.source_vars.keys() == variables.keys()

    def test_vector_variables_only(self, long_df):
        """Vectors can be passed without any dataframe."""
        variables = {"x": long_df["a"], "y": long_df["z"]}
        data = Plot(**variables)._data
        for var, vector in variables.items():
            assert_vector_equal(data.frame[var], vector)
        assert data.source_data is None
        assert data.source_vars.keys() == variables.keys()

    def test_vector_variables_no_index(self, long_df):
        """Arrays and lists are accepted and end up with no variable name."""
        variables = {"x": long_df["a"].to_numpy(), "y": long_df["z"].to_list()}
        data = Plot(**variables)._data
        for var, values in variables.items():
            assert_vector_equal(data.frame[var], pd.Series(values))
            assert data.names[var] is None
        assert data.source_data is None
        assert data.source_vars.keys() == variables.keys()

    def test_data_only_named(self, long_df):
        """A dataframe passed as `data=` is stored as the source data."""
        data = Plot(data=long_df)._data
        assert data.source_data is long_df
        assert data.source_vars == {}

    def test_positional_and_named_data(self, long_df):
        """Passing data both positionally and by keyword raises TypeError."""
        pattern = "`data` given by both name and position"
        with pytest.raises(TypeError, match=pattern):
            Plot(long_df, data=long_df)

    @pytest.mark.parametrize("var", ["x", "y"])
    def test_positional_and_named_xy(self, long_df, var):
        """Passing x or y both positionally and by keyword raises TypeError."""
        pattern = f"`{var}` given by both name and position"
        with pytest.raises(TypeError, match=pattern):
            Plot(long_df, "a", "b", **{var: "c"})

    def test_positional_data_x_y(self, long_df):
        """Three positional arguments are read as data, x, y."""
        data = Plot(long_df, "a", "b")._data
        assert data.source_data is long_df
        assert list(data.source_vars) == ["x", "y"]

    def test_positional_x_y(self, long_df):
        """Two positional vectors are read as x, y with no source data."""
        data = Plot(long_df["a"], long_df["b"])._data
        assert data.source_data is None
        assert list(data.source_vars) == ["x", "y"]

    def test_positional_data_x(self, long_df):
        """A dataframe followed by one name is read as data, x."""
        data = Plot(long_df, "a")._data
        assert data.source_data is long_df
        assert list(data.source_vars) == ["x"]

    def test_positional_x(self, long_df):
        """A single positional vector is read as x with no source data."""
        data = Plot(long_df["a"])._data
        assert data.source_data is None
        assert list(data.source_vars) == ["x"]

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_positional_interchangeable_dataframe(self, mock_long_df, long_df):
        """The source data built from `mock_long_df` equals `long_df`."""
        plot = Plot(mock_long_df, x="x")
        assert_frame_equal(plot._data.source_data, long_df)

    def test_positional_too_many(self, long_df):
        """More than three positional arguments raises TypeError."""
        pattern = r"Plot\(\) accepts no more than 3 positional arguments \(data, x, y\)"
        with pytest.raises(TypeError, match=pattern):
            Plot(long_df, "x", "y", "z")

    def test_unknown_keywords(self, long_df):
        """An unrecognized keyword argument raises TypeError naming it."""
        pattern = r"Plot\(\) got unexpected keyword argument\(s\): bad"
        with pytest.raises(TypeError, match=pattern):
            Plot(long_df, bad="x")


class TestLayerAddition:
    """Tests for `Plot.add`: layer data, transforms, orient, and type checks."""

    def _assert_xy_layer(self, p, source):
        """Assert that the single layer of `p` has x and y columns matching `source`."""
        [layer] = p._layers
        frame = layer["data"].frame
        assert frame.columns.to_list() == ["x", "y"]
        for var in ("x", "y"):
            assert_vector_equal(frame[var], source[var])

    def test_without_data(self, long_df):
        """A layer added without data or variables uses the plot-level frame."""
        p = Plot(long_df, x="x", y="y").add(MockMark()).plot()
        [layer] = p._layers
        assert_frame_equal(p._data.frame, layer["data"].frame, check_dtype=False)

    def test_with_new_variable_by_name(self, long_df):
        """A layer can add a variable by column name."""
        p = Plot(long_df, x="x").add(MockMark(), y="y").plot()
        self._assert_xy_layer(p, long_df)

    def test_with_new_variable_by_vector(self, long_df):
        """A layer can add a variable as a vector."""
        p = Plot(long_df, x="x").add(MockMark(), y=long_df["y"]).plot()
        self._assert_xy_layer(p, long_df)

    def test_with_late_data_definition(self, long_df):
        """Data and variables can all be given in the layer."""
        p = Plot().add(MockMark(), data=long_df, x="x", y="y").plot()
        self._assert_xy_layer(p, long_df)

    def test_with_new_data_definition(self, long_df):
        """Layer-level data replaces plot-level data for that layer."""
        subset = long_df.sample(frac=.5)

        p = Plot(long_df, x="x", y="y").add(MockMark(), data=subset).plot()
        [layer] = p._layers
        frame = layer["data"].frame
        assert frame.columns.to_list() == ["x", "y"]
        for var in ("x", "y"):
            expected = subset[var].reindex(long_df.index)
            assert_vector_equal(frame[var], expected)

    def test_drop_variable(self, long_df):
        """Passing `None` for a variable removes it from the layer."""
        p = Plot(long_df, x="x", y="y").add(MockMark(), y=None).plot()
        [layer] = p._layers
        frame = layer["data"].frame
        assert frame.columns.to_list() == ["x"]
        assert_vector_equal(frame["x"], long_df["x"], check_dtype=False)

    @pytest.mark.xfail(reason="Need decision on default stat")
    def test_stat_default(self):
        """Expected to fail: the mark's `default_stat` is used for the layer."""

        class MarkWithDefaultStat(Mark):
            default_stat = Stat

        p = Plot().add(MarkWithDefaultStat())
        [layer] = p._layers
        assert layer["stat"].__class__ is Stat

    def test_stat_nondefault(self):
        """An explicitly passed stat is the one stored on the layer."""

        class MarkWithDefaultStat(Mark):
            default_stat = Stat

        class OtherMockStat(Stat):
            pass

        p = Plot().add(MarkWithDefaultStat(), OtherMockStat())
        [layer] = p._layers
        assert layer["stat"].__class__ is OtherMockStat

    @pytest.mark.parametrize(
        "arg,expected",
        [("x", "x"), ("y", "y"), ("v", "x"), ("h", "y")],
    )
    def test_orient(self, arg, expected):
        """The `orient` argument reaches stats and moves as "x" or "y"."""

        class MockStatTrackOrient(Stat):
            def __call__(self, data, groupby, orient, scales):
                self.orient_at_call = orient
                return data

        class MockMoveTrackOrient(Move):
            def __call__(self, data, groupby, orient, scales):
                self.orient_at_call = orient
                return data

        stat = MockStatTrackOrient()
        move = MockMoveTrackOrient()
        base = Plot(x=[1, 2, 3], y=[1, 2, 3])
        base.add(MockMark(), stat, move, orient=arg).plot()

        assert stat.orient_at_call == expected
        assert move.orient_at_call == expected

    def test_variable_list(self, long_df):
        """`Plot._variables` lists plot, layer, and pair variables in order."""
        cases = [
            (Plot(long_df, x="x", y="y"), ["x", "y"]),
            (Plot(long_df).add(MockMark(), x="x", y="y"), ["x", "y"]),
            (
                Plot(long_df, y="x", color="a").add(MockMark(), x="y"),
                ["y", "color", "x"],
            ),
            (
                Plot(long_df, x="x", y="y", color="a").add(MockMark(), color=None),
                ["x", "y", "color"],
            ),
            (
                Plot(long_df, x="x", y="y")
                .add(MockMark(), color="a")
                .add(MockMark(), alpha="s"),
                ["x", "y", "color", "alpha"],
            ),
            (Plot(long_df, y="x").pair(x=["a", "b"]), ["y", "x0", "x1"]),
        ]
        for p, expected in cases:
            assert p._variables == expected

    def test_type_checks(self):
        """`add` raises TypeError for a Mark class or invalid transforms."""
        base = Plot()
        with pytest.raises(TypeError, match="mark must be a Mark instance"):
            base.add(MockMark)

        class MockStat(Stat):
            pass

        class MockMove(Move):
            pass

        pattern = "Transforms must have at most one Stat type"

        # A Stat class rather than an instance
        with pytest.raises(TypeError, match=pattern):
            base.add(MockMark(), MockStat)

        # A Stat that comes after a Move
        with pytest.raises(TypeError, match=pattern):
            base.add(MockMark(), MockMove(), MockStat())

        # A Mark in the transform position
        with pytest.raises(TypeError, match=pattern):
            base.add(MockMark(), MockMark(), MockStat())


class TestScaling:
    """Tests for scale inference, scale configuration, and the data passed to marks."""

    def test_inference(self, long_df):
        """Columns z, a, t get Continuous, Nominal, Temporal scales respectively."""

        for col, scale_type in zip("zat", ["Continuous", "Nominal", "Temporal"]):
            p = Plot(long_df, x=col, y=col).add(MockMark()).plot()
            for var in "xy":
                assert p._scales[var].__class__.__name__ == scale_type

    def test_inference_from_layer_data(self):
        """A scale is inferred from a variable that is defined only in a layer."""

        p = Plot().add(MockMark(), x=["a", "b", "c"]).plot()
        assert p._scales["x"]("b") == 1

    def test_inference_joins(self):
        """The x scale accounts for the x data of both layers."""

        p = (
            Plot(y=pd.Series([1, 2, 3, 4]))
            .add(MockMark(), x=pd.Series([1, 2]))
            .add(MockMark(), x=pd.Series(["a", "b"], index=[2, 3]))
            .plot()
        )
        assert p._scales["x"]("a") == 2

    def test_inferred_categorical_converter(self):
        """String data sets up axis unit conversion on the x axis."""

        p = Plot(x=["b", "c", "a"]).add(MockMark()).plot()
        ax = p._figure.axes[0]
        assert ax.xaxis.convert_units("c") == 1

    def test_explicit_categorical_converter(self):
        """An explicit Nominal scale on numeric data converts string units."""

        p = Plot(y=[2, 1, 3]).scale(y=Nominal()).add(MockMark()).plot()
        ax = p._figure.axes[0]
        assert ax.yaxis.convert_units("3") == 2

    @pytest.mark.xfail(reason="Temporal auto-conversion not implemented")
    def test_categorical_as_datetime(self):
        """Placeholder (expected to fail) for scaling date strings temporally."""

        dates = ["1970-01-03", "1970-01-02", "1970-01-04"]
        p = Plot(x=dates).scale(...).add(MockMark()).plot()
        p  # TODO
        ...

    def test_faceted_log_scale(self):
        """A log scale on y applies to the y axis of every facet."""

        p = Plot(y=[1, 10]).facet(col=["a", "b"]).scale(y="log").plot()
        for ax in p._figure.axes:
            xfm = ax.yaxis.get_transform().transform
            assert_array_equal(xfm([1, 10, 100]), [0, 1, 2])

    def test_paired_single_log_scale(self):
        """A log scale given for x1 leaves the first paired x axis linear."""

        x0, x1 = [1, 2, 3], [1, 10, 100]
        p = Plot().pair(x=[x0, x1]).scale(x1="log").plot()
        ax_lin, ax_log = p._figure.axes
        xfm_lin = ax_lin.xaxis.get_transform().transform
        assert_array_equal(xfm_lin([1, 10, 100]), [1, 10, 100])
        xfm_log = ax_log.xaxis.get_transform().transform
        assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2])

    def test_paired_with_common_fallback(self):
        """A scale given for x is used by paired axes without their own scale."""

        x0, x1 = [1, 2, 3], [1, 10, 100]
        p = Plot().pair(x=[x0, x1]).scale(x="pow", x1="log").plot()
        ax_pow, ax_log = p._figure.axes
        xfm_pow = ax_pow.xaxis.get_transform().transform
        assert_array_equal(xfm_pow([1, 2, 3]), [1, 4, 9])
        xfm_log = ax_log.xaxis.get_transform().transform
        assert_array_equal(xfm_log([1, 10, 100]), [0, 1, 2])

    @pytest.mark.xfail(reason="Custom log scale needs log name for consistency")
    def test_log_scale_name(self):
        """Expected to fail: the axis scale name should be reported as "log"."""

        p = Plot().scale(x="log").plot()
        ax = p._figure.axes[0]
        assert ax.get_xscale() == "log"
        assert ax.get_yscale() == "linear"

    def test_mark_data_log_transform_is_inverted(self, long_df):
        """With a log scale, the mark still receives the original data values."""

        col = "z"
        m = MockMark()
        Plot(long_df, x=col).scale(x="log").add(m).plot()
        assert_vector_equal(m.passed_data[0]["x"], long_df[col])

    def test_mark_data_log_transfrom_with_stat(self, long_df):
        """With a log scale, the stat's mean matches a mean taken in log space."""

        class Mean(Stat):
            group_by_orient = True

            def __call__(self, data, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return groupby.agg(data, {other: "mean"})

        col = "z"
        grouper = "a"
        m = MockMark()
        s = Mean()

        Plot(long_df, x=grouper, y=col).scale(y="log").add(m, s).plot()

        # Mean of the logged values per group, mapped back to data units
        expected = (
            long_df[col]
            .pipe(np.log)
            .groupby(long_df[grouper], sort=False)
            .mean()
            .pipe(np.exp)
            .reset_index(drop=True)
        )
        assert_vector_equal(m.passed_data[0]["y"], expected)

    def test_mark_data_from_categorical(self, long_df):
        """Categorical data reaches the mark as float indices into the level order."""

        col = "a"
        m = MockMark()
        Plot(long_df, x=col).add(m).plot()

        levels = categorical_order(long_df[col])
        level_map = {x: float(i) for i, x in enumerate(levels)}
        assert_vector_equal(m.passed_data[0]["x"], long_df[col].map(level_map))

    def test_mark_data_from_datetime(self, long_df):
        """Datetime data reaches the mark as matplotlib date numbers."""

        col = "t"
        m = MockMark()
        Plot(long_df, x=col).add(m).plot()

        expected = long_df[col].map(mpl.dates.date2num)
        assert_vector_equal(m.passed_data[0]["x"], expected)

    def test_computed_var_ticks(self, long_df):
        """Tick settings apply to a variable that only a stat produces."""

        class Identity(Stat):
            def __call__(self, df, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return df.assign(**{other: df[orient]})

        tick_locs = [1, 2, 5]
        scale = Continuous().tick(at=tick_locs)
        p = Plot(long_df, "x").add(MockMark(), Identity()).scale(y=scale).plot()
        ax = p._figure.axes[0]
        assert_array_equal(ax.get_yticks(), tick_locs)

    def test_computed_var_transform(self, long_df):
        """A log scale applies to a variable that only a stat produces."""

        class Identity(Stat):
            def __call__(self, df, groupby, orient, scales):
                other = {"x": "y", "y": "x"}[orient]
                return df.assign(**{other: df[orient]})

        p = Plot(long_df, "x").add(MockMark(), Identity()).scale(y="log").plot()
        ax = p._figure.axes[0]
        xfm = ax.yaxis.get_transform().transform
        assert_array_equal(xfm([1, 10, 100]), [0, 1, 2])

    def test_explicit_range_with_axis_scaling(self):
        """With a log y scale, the mark receives ymax in original units."""

        x = [1, 2, 3]
        ymin = [10, 100, 1000]
        ymax = [20, 200, 2000]
        m = MockMark()
        Plot(x=x, ymin=ymin, ymax=ymax).add(m).scale(y="log").plot()
        assert_vector_equal(m.passed_data[0]["ymax"], pd.Series(ymax, dtype=float))

    def test_derived_range_with_axis_scaling(self):
        """A stat adding 1 to y under a log scale gives a mark ymax of 10 * y."""

        class AddOne(Stat):
            def __call__(self, df, *args):
                return df.assign(ymax=df["y"] + 1)

        x = y = [1, 10, 100]

        m = MockMark()
        Plot(x, y).add(m, AddOne()).scale(y="log").plot()
        assert_vector_equal(m.passed_data[0]["ymax"], pd.Series([10., 100., 1000.]))

    def test_facet_categories(self):
        """By default both facets show three x ticks and share one x coding."""

        m = MockMark()
        p = Plot(x=["a", "b", "a", "c"]).facet(col=["x", "x", "y", "y"]).add(m).plot()
        ax1, ax2 = p._figure.axes
        assert len(ax1.get_xticks()) == 3
        assert len(ax2.get_xticks()) == 3
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3]))

    def test_facet_categories_unshared(self):
        """With `share(x=False)`, each facet has two ticks and its own x coding."""

        m = MockMark()
        p = (
            Plot(x=["a", "b", "a", "c"])
            .facet(col=["x", "x", "y", "y"])
            .share(x=False)
            .add(m)
            .plot()
        )
        ax1, ax2 = p._figure.axes
        assert len(ax1.get_xticks()) == 2
        assert len(ax2.get_xticks()) == 2
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [2, 3]))

    def test_facet_categories_single_dim_shared(self):
        """With `share(x="row")`, check the ticks and x coding of each facet."""

        data = [
            ("a", 1, 1), ("b", 1, 1),
            ("a", 1, 2), ("c", 1, 2),
            ("b", 2, 1), ("d", 2, 1),
            ("e", 2, 2), ("e", 2, 1),
        ]
        df = pd.DataFrame(data, columns=["x", "row", "col"]).assign(y=1)
        m = MockMark()
        p = (
            Plot(df, x="x")
            .facet(row="row", col="col")
            .add(m)
            .share(x="row")
            .plot()
        )

        axs = p._figure.axes
        for ax in axs:
            assert ax.get_xticks() == [0, 1, 2]

        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [2, 3]))
        assert_vector_equal(m.passed_data[2]["x"], pd.Series([0., 1., 2.], [4, 5, 7]))
        assert_vector_equal(m.passed_data[3]["x"], pd.Series([2.], [6]))

    def test_pair_categories(self):
        """By default each paired x variable gets its own two-level coding."""

        data = [("a", "a"), ("b", "c")]
        df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1)
        m = MockMark()
        p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).plot()

        ax1, ax2 = p._figure.axes
        assert ax1.get_xticks() == [0, 1]
        assert ax2.get_xticks() == [0, 1]
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 1.], [0, 1]))

    def test_pair_categories_shared(self):
        """With `share(x=True)`, paired x variables use one three-level coding."""

        data = [("a", "a"), ("b", "c")]
        df = pd.DataFrame(data, columns=["x1", "x2"]).assign(y=1)
        m = MockMark()
        p = Plot(df, y="y").pair(x=["x1", "x2"]).add(m).share(x=True).plot()

        for ax in p._figure.axes:
            assert ax.get_xticks() == [0, 1, 2]
        assert_vector_equal(m.passed_data[0]["x"], pd.Series([0., 1.], [0, 1]))
        assert_vector_equal(m.passed_data[1]["x"], pd.Series([0., 2.], [0, 1]))

    def test_identity_mapping_linewidth(self):
        """A `None` scale for linewidth passes the values through unchanged."""

        m = MockMark()
        x = y = [1, 2, 3, 4, 5]
        lw = pd.Series([.5, .1, .1, .9, 3])
        Plot(x=x, y=y, linewidth=lw).scale(linewidth=None).add(m).plot()
        assert_vector_equal(m.passed_scales["linewidth"](lw), lw)

    def test_pair_single_coordinate_stat_orient(self, long_df):
        """A stat in a plot paired only on x is called with orient "x"."""

        class MockStat(Stat):
            def __call__(self, data, groupby, orient, scales):
                self.orient = orient
                return data

        s = MockStat()
        Plot(long_df).pair(x=["x", "y"]).add(MockMark(), s).plot()
        assert s.orient == "x"

    def test_inferred_nominal_passed_to_stat(self):
        """The scales handed to a stat include the inferred Nominal y scale."""

        class MockStat(Stat):
            def __call__(self, data, groupby, orient, scales):
                self.scales = scales
                return data

        s = MockStat()
        y = ["a", "a", "b", "c"]
        Plot(y=y).add(MockMark(), s).plot()
        assert s.scales["y"].__class__.__name__ == "Nominal"

    # TODO where should RGB consistency be enforced?
    @pytest.mark.xfail(
        reason="Correct output representation for color with identity scale undefined"
    )
    def test_identity_mapping_color_strings(self):
        """Expected to fail: identity-scaled color strings should map to RGB."""

        m = MockMark()
        x = y = [1, 2, 3]
        c = ["C0", "C2", "C1"]
        Plot(x=x, y=y, color=c).scale(color=None).add(m).plot()
        expected = mpl.colors.to_rgba_array(c)[:, :3]
        assert_array_equal(m.passed_scales["color"](c), expected)

    def test_identity_mapping_color_tuples(self):
        """Identity-scaled color tuples come back as the matching RGB array."""

        m = MockMark()
        x = y = [1, 2, 3]
        c = [(1, 0, 0), (0, 1, 0), (1, 0, 0)]
        Plot(x=x, y=y, color=c).scale(color=None).add(m).plot()
        expected = mpl.colors.to_rgba_array(c)[:, :3]
        assert_array_equal(m.passed_scales["color"](c), expected)

    @pytest.mark.xfail(
        reason="Need decision on what to do with scale defined for unused variable"
    )
    def test_undefined_variable_raises(self):
        """Expected to fail: a scale for a variable with no data should raise."""

        p = Plot(x=[1, 2, 3], color=["a", "b", "c"]).scale(y=Continuous())
        err = r"No data found for variable\(s\) with explicit scale: {'y'}"
        with pytest.raises(RuntimeError, match=err):
            p.plot()

    def test_nominal_x_axis_tweaks(self):
        """Nominal x: padded limits, hidden gridlines, explicit limits respected."""

        p = Plot(x=["a", "b", "c"], y=[1, 2, 3])
        ax1 = p.plot()._figure.axes[0]
        assert ax1.get_xlim() == (-.5, 2.5)
        assert not any(x.get_visible() for x in ax1.xaxis.get_gridlines())

        lim = (-1, 2.1)
        ax2 = p.limit(x=lim).plot()._figure.axes[0]
        assert ax2.get_xlim() == lim

    def test_nominal_y_axis_tweaks(self):
        """Nominal y: inverted padded limits, hidden gridlines, explicit limits kept."""

        p = Plot(x=[1, 2, 3], y=["a", "b", "c"])
        ax1 = p.plot()._figure.axes[0]
        assert ax1.get_ylim() == (2.5, -.5)
        assert not any(y.get_visible() for y in ax1.yaxis.get_gridlines())

        lim = (-1, 2.1)
        ax2 = p.limit(y=lim).plot()._figure.axes[0]
        assert ax2.get_ylim() == lim


class TestPlotting:
    """Tests of figure construction and of what each mark receives when plotting."""

    def test_matplotlib_object_creation(self):
        """Compiling an empty plot yields a matplotlib Figure holding Axes."""
        plotter = Plot().plot()

        assert isinstance(plotter._figure, mpl.figure.Figure)
        assert all(
            isinstance(subplot["ax"], mpl.axes.Axes) for subplot in plotter._subplots
        )

    def test_empty(self):
        """A layer without data reports zero splits and passes no data along."""
        mark = MockMark()
        Plot().add(mark).plot()

        assert mark.n_splits == 0
        assert not mark.passed_data

    def test_no_orient_variance(self):
        """Data with a constant x coordinate still reaches the mark unchanged."""
        coords = {"x": [0, 0], "y": [1, 2]}
        mark = MockMark()
        Plot(coords["x"], coords["y"]).add(mark).plot()

        received = mark.passed_data[0]
        for var, values in coords.items():
            assert_array_equal(received[var], values)

    def test_single_split_single_layer(self, long_df):
        """Without grouping, the mark gets one split containing the full frame."""
        mark = MockMark()
        plotter = Plot(long_df, x="f", y="z").add(mark).plot()

        assert mark.n_splits == 1
        assert mark.passed_keys[0] == {}

        expected_axes = [subplot["ax"] for subplot in plotter._subplots]
        assert mark.passed_axes == expected_axes

        frame = plotter._data.frame
        received = mark.passed_data[0]
        for name in frame:
            assert_series_equal(received[name], frame[name])

    def test_single_split_multi_layer(self, long_df):
        """Each layer's mark receives that layer's own variable assignments."""

        class NoGroupingMark(MockMark):
            _grouping_props = []

        first_vars = {"color": "a", "linewidth": "z"}
        second_vars = {"color": "b", "pattern": "c"}
        first_mark, second_mark = NoGroupingMark(), NoGroupingMark()

        (
            Plot(long_df)
            .add(first_mark, **first_vars)
            .add(second_mark, **second_vars)
            .plot()
        )

        layers = [(first_mark, first_vars), (second_mark, second_vars)]
        for mark, variables in layers:
            received = mark.passed_data[0]
            for var, col in variables.items():
                assert_vector_equal(received[var], long_df[col])

    def check_splits_single_var(
        self, data, mark, data_vars, split_var, split_col, split_keys
    ):
        """Assert that `mark` saw one split per level of a single variable.

        The i-th split is compared against the rows of `data` where
        `split_col` equals the i-th entry of `split_keys`.
        """
        expected_keys = [{split_var: level} for level in split_keys]
        assert mark.n_splits == len(split_keys)
        assert mark.passed_keys == expected_keys

        for idx, level in enumerate(split_keys):
            subset = data[data[split_col] == level]
            for var, col in data_vars.items():
                assert_array_equal(mark.passed_data[idx][var], subset[col])

    def check_splits_multi_vars(
        self, data, mark, data_vars, split_vars, split_cols, split_keys
    ):
        """Assert that `mark` saw one split per combination of several variables.

        Combinations are enumerated with `itertools.product` over `split_keys`,
        and the i-th split is compared against the matching rows of `data`.
        """
        combos = list(itertools.product(*split_keys))

        assert mark.n_splits == np.prod([len(levels) for levels in split_keys])
        assert mark.passed_keys == [dict(zip(split_vars, combo)) for combo in combos]

        for idx, combo in enumerate(combos):
            # Start from an all-True mask and narrow it one variable at a time
            mask = pd.Series(True, data.index)
            for _, col, level in zip(split_vars, split_cols, combo):
                mask &= data[col] == level
            subset = data[mask]
            for var, col in data_vars.items():
                assert_array_equal(mark.passed_data[idx][var], subset[col])

    @pytest.mark.parametrize(
        "split_var", [
            "color",  # explicitly declared on the Mark
            "group",  # implicitly used for all Mark classes
        ])
    def test_one_grouping_variable(self, long_df, split_var):
        """A single grouping variable splits the data once per level."""
        split_col = "a"
        data_vars = {"x": "f", "y": "z", split_var: split_col}

        mark = MockMark()
        plotter = Plot(long_df, **data_vars).add(mark).plot()

        split_keys = categorical_order(long_df[split_col])
        first_subplot, *_ = plotter._subplots
        # Every split is drawn onto the same (only) axes
        assert mark.passed_axes == [first_subplot["ax"]] * len(split_keys)
        self.check_splits_single_var(
            long_df, mark, data_vars, split_var, split_col, split_keys
        )

    def test_two_grouping_variables(self, long_df):
        """Two grouping variables split the data over the product of levels."""
        split_vars = ["color", "group"]
        split_cols = ["a", "b"]
        data_vars = {"y": "z", **dict(zip(split_vars, split_cols))}

        mark = MockMark()
        plotter = Plot(long_df, **data_vars).add(mark).plot()

        split_keys = [categorical_order(long_df[col]) for col in split_cols]
        first_subplot, *_ = plotter._subplots
        n_combos = len(list(itertools.product(*split_keys)))
        assert mark.passed_axes == [first_subplot["ax"]] * n_combos
        self.check_splits_multi_vars(
            long_df, mark, data_vars, split_vars, split_cols, split_keys
        )

    def test_specified_width(self, long_df):
        """A `width` variable assigned in the layer is passed to the mark."""
        mark = MockMark()
        Plot(long_df, x="x", y="y").add(mark, width="z").plot()

        received = mark.passed_data[0]
        assert_array_almost_equal(received["width"], long_df["z"])

    def test_facets_no_subgroups(self, long_df):
        """Faceting alone produces one split per facet, each on its own axes."""
        split_var, split_col = "col", "b"
        data_vars = {"x": "f", "y": "z"}

        mark = MockMark()
        spec = Plot(long_df, **data_vars).facet(**{split_var: split_col})
        plotter = spec.add(mark).plot()

        split_keys = categorical_order(long_df[split_col])
        assert mark.passed_axes == list(plotter._figure.axes)
        self.check_splits_single_var(
            long_df, mark, data_vars, split_var, split_col, split_keys
        )

    def test_facets_one_subgroup(self, long_df):
        """Faceting plus one grouping variable splits within each facet."""
        facet_var, facet_col = "col", "a"
        group_var, group_col = "group", "b"
        split_vars = (facet_var, group_var)
        split_cols = (facet_col, group_col)
        data_vars = {"x": "f", "y": "z", group_var: group_col}

        mark = MockMark()
        spec = Plot(long_df, **data_vars).facet(**{facet_var: facet_col})
        plotter = spec.add(mark).plot()

        split_keys = [categorical_order(long_df[col]) for col in split_cols]
        n_groups = len(categorical_order(long_df[group_col]))

        # Each facet's axes repeats once per group level
        expected_axes = []
        for ax in list(plotter._figure.axes):
            expected_axes.extend([ax] * n_groups)
        assert mark.passed_axes == expected_axes

        self.check_splits_multi_vars(
            long_df, mark, data_vars, split_vars, split_cols, split_keys
        )

    def test_layer_specific_facet_disabling(self, long_df):
        """A layer with `row=None` receives unfaceted data for every subplot."""
        axis_vars = {"x": "y", "y": "z"}
        row_var = "a"

        mark = MockMark()
        spec = Plot(long_df, **axis_vars).facet(row=row_var)
        plotter = spec.add(mark, row=None).plot()

        row_levels = categorical_order(long_df[row_var])
        assert len(plotter._figure.axes) == len(row_levels)

        for received in mark.passed_data:
            for var, col in axis_vars.items():
                assert_vector_equal(received[var], long_df[col])

    def test_paired_variables(self, long_df):
        """Pairing on x and y hands the mark each (x, y) column combination."""
        x_cols = ["x", "y"]
        y_cols = ["f", "z"]

        mark = MockMark()
        Plot(long_df).pair(x_cols, y_cols).add(mark).plot()

        combos = itertools.product(x_cols, y_cols)
        for received, (x_col, y_col) in zip(mark.passed_data, combos):
            assert_vector_equal(received["x"], long_df[x_col].astype(float))
            assert_vector_equal(received["y"], long_df[y_col].astype(float))

    def test_paired_one_dimension(self, long_df):
        """Pairing only on x hands the mark each x column in turn."""
        x_cols = ["y", "z"]

        mark = MockMark()
        Plot(long_df).pair(x_cols).add(mark).plot()

        for received, x_col in zip(mark.passed_data, x_cols):
            assert_vector_equal(received["x"], long_df[x_col].astype(float))

    def test_paired_variables_one_subset(self, long_df):
        """Pairing combined with a grouping variable splits each pair by group."""
        x_cols = ["x", "y"]
        y_cols = ["f", "z"]
        group_col = "a"

        long_df["x"] = long_df["x"].astype(float)  # simplify vector comparison

        mark = MockMark()
        Plot(long_df, group=group_col).pair(x_cols, y_cols).add(mark).plot()

        group_levels = categorical_order(long_df[group_col])
        combos = itertools.product(x_cols, y_cols, group_levels)

        for received, (x_col, y_col, level) in zip(mark.passed_data, combos):
            in_group = long_df[group_col] == level
            assert_vector_equal(received["x"], long_df.loc[in_group, x_col])
            assert_vector_equal(received["y"], long_df.loc[in_group, y_col])

    def test_paired_and_faceted(self, long_df):
        """Pairing on x while faceting rows splits each pair by facet level."""
        x_cols = ["y", "z"]
        y_col = "f"
        row_col = "c"

        mark = MockMark()
        Plot(long_df, y=y_col).facet(row=row_col).pair(x_cols).add(mark).plot()

        facet_levels = categorical_order(long_df[row_col])
        combos = itertools.product(x_cols, facet_levels)

        for received, (x_col, level) in zip(mark.passed_data, combos):
            in_facet = long_df[row_col] == level
            assert_vector_equal(received["x"], long_df.loc[in_facet, x_col])
            assert_vector_equal(received["y"], long_df.loc[in_facet, y_col])

    def test_theme_default(self):
        """The default theme gives the axes the `#EAEAF2` face color."""
        plotter = Plot().plot()
        facecolor = plotter._figure.axes[0].get_facecolor()
        assert mpl.colors.same_color(facecolor, "#EAEAF2")

    def test_theme_params(self):
        """Parameters passed to `theme` are applied to the axes."""
        color = ".888"
        plotter = Plot().theme({"axes.facecolor": color}).plot()
        facecolor = plotter._figure.axes[0].get_facecolor()
        assert mpl.colors.same_color(facecolor, color)

    def test_theme_error(self):
        """Passing more than one positional argument to `theme` is a TypeError."""
        spec = Plot()
        with pytest.raises(TypeError, match=r"theme\(\) takes 2 positional"):
            spec.theme("arg1", "arg2")

    def test_theme_validation(self):
        """Invalid rc values and unknown rc keys are rejected by `theme`."""
        spec = Plot()

        # You'd think matplotlib would raise a TypeError here, but it doesn't
        with pytest.raises(ValueError, match="Key axes.linewidth:"):
            spec.theme({"axes.linewidth": "thick"})

        with pytest.raises(KeyError, match="not.a.key is not a valid rc"):
            spec.theme({"not.a.key": True})

    def test_stat(self, long_df):
        """A stat transforms the data seen by the mark but not the input frame."""
        pristine = long_df.copy(deep=True)

        mark = MockMark()
        Plot(long_df, x="a", y="z").add(mark, Agg()).plot()

        group_means = long_df.groupby("a", sort=False)["z"].mean()
        assert_vector_equal(
            mark.passed_data[0]["y"], group_means.reset_index(drop=True)
        )

        assert_frame_equal(long_df, pristine)   # Test data was not mutated

    def test_move(self, long_df):
        """A move shifts only its target variable and leaves the input intact."""
        pristine = long_df.copy(deep=True)

        mark = MockMark()
        Plot(long_df, x="z", y="z").add(mark, Shift(x=1)).plot()

        received = mark.passed_data[0]
        assert_vector_equal(received["x"], long_df["z"] + 1)
        assert_vector_equal(received["y"], long_df["z"])

        assert_frame_equal(long_df, pristine)   # Test data was not mutated

    def test_stat_and_move(self, long_df):
        """A move listed after a stat operates on the stat's output."""
        mark = MockMark()
        Plot(long_df, x="a", y="z").add(mark, Agg(), Shift(y=1)).plot()

        group_means = long_df.groupby("a", sort=False)["z"].mean()
        shifted = group_means.reset_index(drop=True) + 1
        assert_vector_equal(mark.passed_data[0]["y"], shifted)

    def test_stat_log_scale(self, long_df):
        """With a log scale, the stat result matches a mean taken in log space."""
        pristine = long_df.copy(deep=True)

        mark = MockMark()
        Plot(long_df, x="a", y="z").add(mark, Agg()).scale(y="log").plot()

        log_z = np.log10(long_df["z"])
        log_means = log_z.groupby(long_df["a"], sort=False).mean()
        log_means = log_means.reset_index(drop=True)
        assert_vector_equal(mark.passed_data[0]["y"], 10 ** log_means)

        assert_frame_equal(long_df, pristine)   # Test data was not mutated

    def test_move_log_scale(self, long_df):
        """With a log scale, a shift of -1 divides the original values by ten."""
        mark = MockMark()
        spec = Plot(long_df, x="z", y="z").scale(x="log")
        spec.add(mark, Shift(x=-1)).plot()

        assert_vector_equal(mark.passed_data[0]["x"], long_df["z"] / 10)

    def test_multi_move(self, long_df):
        """Several moves in one layer accumulate."""
        mark = MockMark()
        moves = [Shift(1), Shift(2)]
        Plot(long_df, x="x", y="y").add(mark, *moves).plot()

        assert_vector_equal(mark.passed_data[0]["x"], long_df["x"] + 3)

    def test_multi_move_with_pairing(self, long_df):
        """Accumulated moves are applied to every paired subplot's data."""
        mark = MockMark()
        moves = [Shift(1), Shift(2)]
        Plot(long_df, x="x").pair(y=["y", "z"]).add(mark, *moves).plot()

        shifted = long_df["x"] + 3
        for received in mark.passed_data:
            assert_vector_equal(received["x"], shifted)

    def test_move_with_range(self, long_df):
        """Dodge offsets x by group when only ymin/ymax describe the data."""
        x = [0, 0, 1, 1, 2, 2]
        group = [0, 1, 0, 1, 0, 1]
        ymin = np.arange(6)
        ymax = ymin * 2

        mark = MockMark()
        Plot(x=x, group=group, ymin=ymin, ymax=ymax).add(mark, Dodge()).plot()

        # First group moves left and second group moves right
        signs = [-1, +1]
        base = np.arange(3)
        for level, group_df in mark.passed_data[0].groupby("group"):
            assert_array_equal(group_df["x"], base + signs[level] * 0.2)

    def test_methods_clone(self, long_df):
        """Chained methods return a new Plot and leave the original untouched."""
        original = Plot(long_df, "x", "y")
        derived = original.add(MockMark()).facet("a")

        assert original is not derived
        assert not original._layers
        assert not original._facet_spec

    def test_default_is_no_pyplot(self):
        """By default, plotting registers no figure with pyplot."""
        plotter = Plot().plot()

        assert not plt.get_fignums()
        assert isinstance(plotter._figure, mpl.figure.Figure)

    def test_with_pyplot(self):
        """With `pyplot=True`, the figure is the current pyplot figure."""
        plotter = Plot().plot(pyplot=True)

        assert len(plt.get_fignums()) == 1
        assert plotter._figure is plt.gcf()

    def test_show(self):
        """`show` returns None, makes a pyplot figure, and leaves the Plot bare."""
        spec = Plot()

        with warnings.catch_warnings(record=True) as caught:
            result = spec.show(block=False)

        assert result is None
        assert not hasattr(spec, "_figure")
        assert len(plt.get_fignums()) == 1

        # From https://github.com/matplotlib/matplotlib/issues/20281
        manager_show = plt.gcf().canvas.manager.show
        base_show = mpl.backend_bases.FigureManagerBase.show
        is_gui_backend = manager_show != base_show

        # A warning is only required when the backend is not a GUI one
        if not is_gui_backend:
            assert caught

    def test_save(self):
        """`save` returns the Plot and writes PNG by default or SVG on request."""
        png_buffer = io.BytesIO()
        returned = Plot().save(png_buffer)
        assert isinstance(returned, Plot)
        assert Image.open(png_buffer).format == "PNG"

        svg_buffer = io.StringIO()
        Plot().save(svg_buffer, format="svg")
        root = xml.etree.ElementTree.fromstring(svg_buffer.getvalue())
        assert root.tag == "{http://www.w3.org/2000/svg}svg"

    def test_layout_size(self):
        """The `size` layout parameter sets the figure size in inches."""
        size = (4, 2)
        plotter = Plot().layout(size=size).plot()
        assert tuple(plotter._figure.get_size_inches()) == size

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="mpl<3.6 does not have get_layout_engine",
    )
    def test_layout_extent(self):
        """`extent` is handed to the default layout engine as a rect."""
        plotter = Plot().layout(extent=(.1, .2, .6, 1)).plot()
        engine = plotter._figure.get_layout_engine()
        assert engine.get()["rect"] == [.1, .2, .5, .8]

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="mpl<3.6 does not have get_layout_engine",
    )
    def test_constrained_layout_extent(self):
        """`extent` is handed to the constrained layout engine as a rect."""
        spec = Plot().layout(engine="constrained", extent=(.1, .2, .6, 1))
        engine = spec.plot()._figure.get_layout_engine()
        assert engine.get()["rect"] == [.1, .2, .5, .8]

    def test_base_layout_extent(self):
        """Without a layout engine, `extent` sets the subplot parameters."""
        plotter = Plot().layout(engine=None, extent=(.1, .2, .6, 1)).plot()
        params = plotter._figure.subplotpars

        assert params.left == 0.1
        assert params.right == 0.6
        assert params.bottom == 0.2
        assert params.top == 1

    def test_on_axes(self):
        """`on` with an Axes draws into that Axes and adopts its figure."""
        target_ax = mpl.figure.Figure().subplots()
        mark = MockMark()
        plotter = Plot([1], [2]).on(target_ax).add(mark).plot()

        assert mark.passed_axes == [target_ax]
        assert plotter._figure is target_ax.figure

    @pytest.mark.parametrize("facet", [True, False])
    def test_on_figure(self, facet):
        """`on` with a Figure draws into axes created on that Figure."""
        target_fig = mpl.figure.Figure()
        mark = MockMark()

        spec = Plot([1, 2], [3, 4]).on(target_fig).add(mark)
        if facet:
            spec = spec.facet(["a", "b"])
        plotter = spec.plot()

        assert mark.passed_axes == target_fig.axes
        assert plotter._figure is target_fig

    @pytest.mark.parametrize("facet", [True, False])
    def test_on_subfigure(self, facet):
        """`on` with a SubFigure adds axes there, after any pre-existing axes."""
        occupied, target = mpl.figure.Figure().subfigures(2)
        occupied.subplots()
        mark = MockMark()

        spec = Plot([1, 2], [3, 4]).on(target).add(mark)
        if facet:
            spec = spec.facet(["a", "b"])
        plotter = spec.plot()

        # The first axes on the parent figure belongs to the other subfigure
        assert mark.passed_axes == target.figure.axes[1:]
        assert plotter._figure is target.figure

    def test_on_type_check(self):
        """`on` rejects targets that are not matplotlib containers."""
        spec = Plot()
        with pytest.raises(TypeError, match="The `Plot.on`.+"):
            spec.on([])

    def test_on_axes_with_subplots_error(self):
        """A plot needing several subplots cannot be drawn onto a single Axes."""
        target_ax = mpl.figure.Figure().subplots()

        faceted = Plot().facet(["a", "b"]).on(target_ax)
        with pytest.raises(RuntimeError, match="Cannot create multiple subplots"):
            faceted.plot()

        paired = Plot().pair([["a", "b"], ["x", "y"]]).on(target_ax)
        with pytest.raises(RuntimeError, match="Cannot create multiple subplots"):
            paired.plot()

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Requires newer matplotlib layout engine API"
    )
    def test_on_layout_algo_default(self):
        """A target figure keeps its own layout engine when none is requested."""

        class MockEngine(mpl.layout_engine.ConstrainedLayoutEngine):
            ...

        target_fig = mpl.figure.Figure(layout=MockEngine())
        plotter = Plot().on(target_fig).plot()
        engine = plotter._figure.get_layout_engine()
        assert engine.__class__.__name__ == "MockEngine"

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Requires newer matplotlib layout engine API"
    )
    def test_on_layout_algo_spec(self):
        """An explicitly requested engine replaces the target figure's engine."""
        target_fig = mpl.figure.Figure(layout="constrained")
        plotter = Plot().on(target_fig).layout(engine="tight").plot()
        engine = plotter._figure.get_layout_engine()
        assert engine.__class__.__name__ == "TightLayoutEngine"

    def test_axis_labels_from_constructor(self, long_df):
        """Axis labels come from variable names given to the constructor."""
        [ax] = Plot(long_df, x="a", y="b").plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "b"

        # A bare array carries no name, so the y label is empty
        unnamed_y = long_df["b"].to_numpy()
        [ax] = Plot(x=long_df["a"], y=unnamed_y).plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == ""

    def test_axis_labels_from_layer(self, long_df):
        """Axis labels come from variable names given to a layer."""
        mark = MockMark()

        [ax] = Plot(long_df).add(mark, x="a", y="b").plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "b"

        # A plain list carries no name, so the y label is empty
        spec = Plot().add(mark, x=long_df["a"], y=long_df["b"].to_list())
        [ax] = spec.plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == ""

    def test_axis_labels_are_first_name(self, long_df):
        """The first available variable name is used for each axis label."""
        mark = MockMark()
        spec = (
            Plot(long_df, x=long_df["z"].to_list(), y="b")
            .add(mark, x="a")
            .add(mark, x="x", y="y")
        )

        [ax] = spec.plot()._figure.axes
        assert ax.get_xlabel() == "a"
        assert ax.get_ylabel() == "b"

    def test_limits(self, long_df):
        """`limit` accepts numeric, datetime, and categorical bounds."""
        numeric = (-2, 24)
        plotter = Plot(long_df, x="x", y="y").limit(x=numeric).plot()
        assert plotter._figure.axes[0].get_xlim() == numeric

        dates = (np.datetime64("2005-01-01"), np.datetime64("2008-01-01"))
        plotter = Plot(long_df, x="d", y="y").limit(x=dates).plot()
        assert plotter._figure.axes[0].get_xlim() == tuple(mpl.dates.date2num(dates))

        categories = ("b", "c")
        spec = Plot(x=["a", "b", "c", "d"], y=[1, 2, 3, 4])
        plotter = spec.limit(x=categories).plot()
        assert plotter._figure.axes[0].get_xlim() == (0.5, 2.5)

    def test_labels_axis(self, long_df):
        """An axis label may be a literal string or a function of the name."""
        base = Plot(long_df, x="x", y="y")

        text = "Y axis"
        ax = base.label(y=text).plot()._figure.axes[0]
        assert ax.get_ylabel() == text

        ax = base.label(y=str.capitalize).plot()._figure.axes[0]
        assert ax.get_ylabel() == "Y"

    def test_labels_legend(self, long_df):
        """A legend title may be a literal string or a function of the name."""
        mark = MockMark()
        expected = "A"

        for label_spec in (expected, str.capitalize):
            spec = Plot(long_df, x="x", y="y", color="a").add(mark)
            plotter = spec.label(color=label_spec).plot()
            legend_title = plotter._figure.legends[0].get_title()
            assert legend_title.get_text() == expected

    def test_labels_facets(self):
        """Facet titles combine the labelled dimension names with the levels."""
        data = {"a": ["b", "c"], "x": ["y", "z"]}
        spec = Plot(data).facet("a", "x").label(col=str.capitalize, row="$x$")
        plotter = spec.plot()

        grid = np.reshape(plotter._figure.axes, (2, 2))
        for (row_idx, col_idx), ax in np.ndenumerate(grid):
            col_level = data["a"][col_idx]
            row_level = data["x"][row_idx]
            assert ax.get_title() == f"A {col_level} | $x$ {row_level}"

    def test_title_single(self):
        """A `title` label on an unfaceted plot becomes the axes title."""
        text = "A"
        plotter = Plot().label(title=text).plot()
        assert plotter._figure.axes[0].get_title() == text

    def test_title_facet_function(self):
        """A `title` function is applied to each facet level in the titles."""
        titles = ["a", "b"]
        plotter = Plot().facet(titles).label(title=str.capitalize).plot()
        for idx, ax in enumerate(plotter._figure.axes):
            assert ax.get_title() == titles[idx].upper()

        cols, rows = ["a", "b"], ["x", "y"]
        plotter = Plot().facet(cols, rows).label(title=str.capitalize).plot()
        for idx, ax in enumerate(plotter._figure.axes):
            row_idx, col_idx = divmod(idx, 2)
            parts = [cols[col_idx].upper(), rows[row_idx].upper()]
            assert ax.get_title() == " | ".join(parts)


class TestExceptions:
    """Tests that scale failures surface as `PlotSpecError` with a chained cause."""

    def test_scale_setup(self):
        """An unknown palette fails scale setup and names the `color` variable."""
        values = ["a", "b"]
        bad_palette = "not_a_palette"
        spec = Plot(values, values, color=values).add(MockMark())
        spec = spec.scale(color=bad_palette)

        expected_msg = "Scale setup failed for the `color` variable."
        with pytest.raises(PlotSpecError, match=expected_msg) as excinfo:
            spec.plot()

        cause = excinfo.value.__cause__
        assert isinstance(cause, ValueError)
        assert bad_palette in str(cause)

    def test_coordinate_scaling(self):
        """A temporal scale on string data fails while scaling the `x` variable."""
        spec = Plot(["a", "b"], [1, 2]).add(MockMark()).scale(x=Temporal())

        expected_msg = "Scaling operation failed for the `x` variable."
        with pytest.raises(PlotSpecError, match=expected_msg) as excinfo:
            spec.plot()

        # Don't test the cause contents b/c matplotlib owns them here.
        assert hasattr(excinfo.value, "__cause__")

    def test_semantic_scaling(self):
        """An error raised inside a scale pipeline is chained as the cause."""

        class ErrorRaising(Continuous):

            def _setup(self, data, prop, axis=None):

                def always_fail(x):
                    raise ValueError("This is a test")

                # Replace the real pipeline with a step that always raises
                configured = super()._setup(data, prop, axis)
                configured._pipeline = [always_fail]
                return configured

        values = [1, 2]
        spec = Plot(values, values, color=values).add(Dot())
        spec = spec.scale(color=ErrorRaising())

        expected_msg = "Scaling operation failed for the `color` variable."
        with pytest.raises(PlotSpecError, match=expected_msg) as excinfo:
            spec.plot()

        cause = excinfo.value.__cause__
        assert isinstance(cause, ValueError)
        assert str(cause) == "This is a test"


class TestFacetInterface:
    """Tests of `Plot.facet`: subplot structure, ordering, sharing, and wrapping."""

    @pytest.fixture(scope="class", params=["row", "col"])
    def dim(self, request):
        """Parametrize a test over the two facet dimensions."""
        return request.param

    @pytest.fixture(scope="class", params=["reverse", "subset", "expand"])
    def reorder(self, request):
        """Return a function that reverses, truncates, or extends a level list."""
        return {
            "reverse": lambda x: x[::-1],
            "subset": lambda x: x[:-1],
            "expand": lambda x: x + ["z"],
        }[request.param]

    def check_facet_results_1d(self, p, df, dim, key, order=None):
        """Plot `p` and verify a one-dimensional facet grid.

        Parameters
        ----------
        p : Plot
            Faceted plot specification; it is compiled here.
        df : DataFrame
            Data used to derive the facet levels.
        dim : {"row", "col"}
            Dimension that was faceted.
        key : str
            Column of `df` that defines the facets.
        order : list, optional
            Explicit level order; passed through to `categorical_order`.

        """
        p = p.plot()

        order = categorical_order(df[key], order)
        assert len(p._figure.axes) == len(order)

        other_dim = {"row": "col", "col": "row"}[dim]

        for subplot, level in zip(p._subplots, order):
            assert subplot[dim] == level
            assert subplot[other_dim] is None
            assert subplot["ax"].get_title() == f"{level}"
            assert_gridspec_shape(subplot["ax"], **{f"n{dim}s": len(order)})

    def test_1d(self, long_df, dim):
        """Faceting by a column name produces one subplot per level."""
        key = "a"
        p = Plot(long_df).facet(**{dim: key})
        self.check_facet_results_1d(p, long_df, dim, key)

    def test_1d_as_vector(self, long_df, dim):
        """Faceting by a data vector behaves like faceting by its column name."""
        key = "a"
        p = Plot(long_df).facet(**{dim: long_df[key]})
        self.check_facet_results_1d(p, long_df, dim, key)

    def test_1d_with_order(self, long_df, dim, reorder):
        """An explicit `order` controls which facets exist and their sequence."""
        key = "a"
        order = reorder(categorical_order(long_df[key]))
        p = Plot(long_df).facet(**{dim: key, "order": order})
        self.check_facet_results_1d(p, long_df, dim, key, order)

    def check_facet_results_2d(self, p, df, variables, order=None):
        """Plot `p` and verify a two-dimensional (row x col) facet grid.

        Parameters
        ----------
        p : Plot
            Faceted plot specification; it is compiled here.
        df : DataFrame
            Data used to derive the facet levels when `order` is not given.
        variables : dict
            Mapping from facet dimension ("row" / "col") to a column of `df`.
        order : dict, optional
            Mapping from facet dimension to an explicit list of levels.

        """
        p = p.plot()

        if order is None:
            order = {dim: categorical_order(df[key]) for dim, key in variables.items()}

        # Materialize the product: a bare iterator would be consumed by the
        # length check, and the loop below would then silently never execute.
        levels = list(itertools.product(*[order[dim] for dim in ["row", "col"]]))
        assert len(p._subplots) == len(levels)

        n_rows, n_cols = len(order["row"]), len(order["col"])
        for subplot, (row_level, col_level) in zip(p._subplots, levels):
            assert subplot["row"] == row_level
            assert subplot["col"] == col_level
            assert subplot["ax"].get_title() == f"{col_level} | {row_level}"
            assert_gridspec_shape(subplot["ax"], n_rows, n_cols)

    def test_2d(self, long_df):
        """Faceting on both dimensions produces the full grid of level pairs."""
        variables = {"row": "a", "col": "c"}
        p = Plot(long_df).facet(**variables)
        self.check_facet_results_2d(p, long_df, variables)

    def test_2d_with_order(self, long_df, reorder):
        """A per-dimension `order` dict controls the two-dimensional grid."""
        variables = {"row": "a", "col": "c"}
        order = {
            dim: reorder(categorical_order(long_df[key]))
            for dim, key in variables.items()
        }

        p = Plot(long_df).facet(**variables, order=order)
        self.check_facet_results_2d(p, long_df, variables, order)

    @pytest.mark.parametrize("algo", ["tight", "constrained"])
    def test_layout_algo(self, algo):
        """A layout engine leaves no more space between facets than no engine."""
        p = Plot().facet(["a", "b"]).limit(x=(.1, .9))

        p1 = p.layout(engine=algo).plot()
        p2 = p.layout(engine="none").plot()

        # Force a draw (we probably need a method for this)
        p1.save(io.BytesIO())
        p2.save(io.BytesIO())

        bb11, bb12 = [ax.get_position() for ax in p1._figure.axes]
        bb21, bb22 = [ax.get_position() for ax in p2._figure.axes]

        # Difference in x position between a corner of the second axes and a
        # corner of the first, used as the separation between the two facets
        sep1 = bb12.corners()[0, 0] - bb11.corners()[2, 0]
        sep2 = bb22.corners()[0, 0] - bb21.corners()[2, 0]
        assert sep1 <= sep2

    def test_axis_sharing(self, long_df):
        """Facets share both axes by default; `share` can disable or restrict it."""
        variables = {"row": "a", "col": "c"}

        p = Plot(long_df).facet(**variables)

        # Default: every subplot shares x and y with the first one
        p1 = p.plot()
        root, *other = p1._figure.axes
        for axis in "xy":
            shareset = getattr(root, f"get_shared_{axis}_axes")()
            assert all(shareset.joined(root, ax) for ax in other)

        # Sharing disabled on both axes
        p2 = p.share(x=False, y=False).plot()
        root, *other = p2._figure.axes
        for axis in "xy":
            shareset = getattr(root, f"get_shared_{axis}_axes")()
            assert not any(shareset.joined(root, ax) for ax in other)

        # Sharing x within columns and y within rows only
        p3 = p.share(x="col", y="row").plot()
        shape = (
            len(categorical_order(long_df[variables["row"]])),
            len(categorical_order(long_df[variables["col"]])),
        )
        axes_matrix = np.reshape(p3._figure.axes, shape)

        # Iterating the matrix walks rows (y shared, x not); iterating the
        # transpose walks columns (x shared, y not)
        for (shared, unshared), vectors in zip(
            ["yx", "xy"], [axes_matrix, axes_matrix.T]
        ):
            for root, *other in vectors:
                shareset = {
                    axis: getattr(root, f"get_shared_{axis}_axes")() for axis in "xy"
                }
                assert all(shareset[shared].joined(root, ax) for ax in other)
                assert not any(shareset[unshared].joined(root, ax) for ax in other)

    def test_unshared_spacing(self):
        """With unshared x axes, the width passed to the mark differs per facet."""
        x = [1, 2, 10, 20]
        y = [1, 2, 3, 4]
        col = [1, 1, 2, 2]

        m = MockMark()
        Plot(x, y).facet(col).add(m).share(x=False).plot()
        assert_array_almost_equal(m.passed_data[0]["width"], [0.8, 0.8])
        assert_array_equal(m.passed_data[1]["width"], [8, 8])

    def test_col_wrapping(self):
        """Wrapping column facets starts a new grid row after `wrap` columns."""
        cols = list("abcd")
        wrap = 3
        p = Plot().facet(col=cols, wrap=wrap).plot()

        assert len(p._figure.axes) == 4
        assert_gridspec_shape(p._figure.axes[0], len(cols) // wrap + 1, wrap)

        # TODO test axis labels and titles

    def test_row_wrapping(self):
        """Wrapping row facets starts a new grid column after `wrap` rows."""
        rows = list("abcd")
        wrap = 3
        p = Plot().facet(row=rows, wrap=wrap).plot()

        assert_gridspec_shape(p._figure.axes[0], wrap, len(rows) // wrap + 1)
        assert len(p._figure.axes) == 4

        # TODO test axis labels and titles


class TestPairInterface:

    def check_pair_grid(self, p, x, y):
        """Verify axis labels and grid shape for a crossed pair plot.

        Parameters
        ----------
        p : Plotter
            Compiled plot whose `_subplots` are inspected.
        x, y : list
            Variable names paired on each axis; a `None` entry means the
            corresponding axis is expected to have an empty label.

        """
        xys = itertools.product(y, x)

        for (y_i, x_j), subplot in zip(xys, p._subplots):

            ax = subplot["ax"]
            # Parenthesize the conditional: without it the statement parses as
            # `assert (label == "") if name is None else name`, which is
            # vacuously true for any non-empty variable name.
            assert ax.get_xlabel() == ("" if x_j is None else x_j)
            assert ax.get_ylabel() == ("" if y_i is None else y_i)
            assert_gridspec_shape(subplot["ax"], len(y), len(x))

    @pytest.mark.parametrize("vector_type", [list, pd.Index])
    def test_all_numeric(self, long_df, vector_type):
        """Pairing accepts variable keys wrapped in different sequence types."""
        x_keys = ["x", "y", "z"]
        y_keys = ["s", "f"]

        spec = Plot(long_df).pair(vector_type(x_keys), vector_type(y_keys))
        self.check_pair_grid(spec.plot(), x_keys, y_keys)

    def test_single_variable_key_raises(self, long_df):
        """Passing a bare key rather than a sequence to `pair` is a TypeError."""
        spec = Plot(long_df)
        expected_msg = "You must pass a sequence of variable keys to `y`"
        with pytest.raises(TypeError, match=expected_msg):
            spec.pair(x=["x", "y"], y="z")

    @pytest.mark.parametrize("dim", ["x", "y"])
    def test_single_dimension(self, long_df, dim):
        """Pairing along only one axis yields a one-dimensional grid."""
        pair_vars = dict.fromkeys("xy")
        pair_vars[dim] = ["x", "y", "z"]
        plotter = Plot(long_df).pair(**pair_vars).plot()

        # The unpaired axis is represented as a single unnamed entry
        grid_vars = {}
        for axis, keys in pair_vars.items():
            grid_vars[axis] = [None] if keys is None else keys
        self.check_pair_grid(plotter, **grid_vars)

    def test_non_cross(self, long_df):
        """With `cross=False`, x and y variables are paired up element-wise."""
        x_keys = ["x", "y"]
        y_keys = ["f", "z"]

        plotter = Plot(long_df).pair(x_keys, y_keys, cross=False).plot()

        for idx, subplot in enumerate(plotter._subplots):
            ax = subplot["ax"]
            assert ax.get_xlabel() == x_keys[idx]
            assert ax.get_ylabel() == y_keys[idx]
            assert_gridspec_shape(ax, 1, len(x_keys))

        # Neither axis is shared between any of the subplots
        first, *rest = plotter._figure.axes
        for axis in "xy":
            siblings = getattr(first, f"get_shared_{axis}_axes")()
            for ax in rest:
                assert not siblings.joined(first, ax)

    def test_list_of_vectors(self, long_df):
        """Pairing on a list of data vectors labels each axes with the vector name."""
        x_names = ["x", "z"]
        x_vectors = [long_df[name] for name in x_names]

        plotter = Plot(long_df, y="y").pair(x=x_vectors).plot()
        axes = plotter._figure.axes

        assert len(axes) == len(x_names)
        for ax, name in zip(axes, x_names):
            assert ax.get_xlabel() == name

    def test_with_no_variables(self, long_df):
        """Calling `pair` without variables still produces a single axes."""
        plotter = Plot(long_df).pair().plot()
        n_axes = len(plotter._figure.axes)
        assert n_axes == 1

    def test_with_facets(self, long_df):
        """Pairing on y combines with column faceting into a y-by-facet grid."""
        x_key = "x"
        y_keys = ["y", "z"]
        facet_key = "a"

        spec = Plot(long_df, x=x_key).facet(facet_key).pair(y=y_keys)
        plotter = spec.plot()

        facet_levels = categorical_order(long_df[facet_key])
        expected = itertools.product(y_keys, facet_levels)

        for (y_key, level), subplot in zip(expected, plotter._subplots):
            ax = subplot["ax"]
            assert ax.get_xlabel() == x_key
            assert ax.get_ylabel() == y_key
            assert ax.get_title() == f"{level}"
            assert_gridspec_shape(ax, len(y_keys), len(facet_levels))

    @pytest.mark.parametrize("variables", [("rows", "y"), ("columns", "x")])
    def test_error_on_facet_overlap(self, long_df, variables):
        """Faceting and pairing along the same grid dimension raises on plot."""
        facet_dim, pair_axis = variables

        # Facet keywords are the first three letters: "rows" -> "row", etc.
        facet_kws = {facet_dim[:3]: "a"}
        pair_kws = {pair_axis: ["x", "y"]}
        spec = Plot(long_df).facet(**facet_kws).pair(**pair_kws)

        expected = f"Cannot facet the {facet_dim} while pairing on `{pair_axis}`."
        with pytest.raises(RuntimeError, match=expected):
            spec.plot()

    @pytest.mark.parametrize("variables", [("columns", "y"), ("rows", "x")])
    def test_error_on_wrap_overlap(self, long_df, variables):
        """Wrapping facets along the dimension used for pairing raises on plot."""
        facet_dim, pair_axis = variables

        # Facet keywords are the first three letters: "rows" -> "row", etc.
        facet_kws = {facet_dim[:3]: "a"}
        pair_kws = {pair_axis: ["x", "y"]}
        spec = Plot(long_df).facet(wrap=2, **facet_kws).pair(**pair_kws)

        expected = f"Cannot wrap the {facet_dim} while pairing on `{pair_axis}``."
        with pytest.raises(RuntimeError, match=expected):
            spec.plot()

    def test_axis_sharing(self, long_df):
        """Paired grids share y along rows and x along columns by default."""
        spec = Plot(long_df).pair(x=["a", "b"], y=["y", "z"])
        grid_shape = (2, 2)

        default_plot = spec.plot()
        grid = np.reshape(default_plot._figure.axes, grid_shape)

        for first, *rest in grid:  # Test row-wise sharing
            x_group = first.get_shared_x_axes()
            assert not any(x_group.joined(first, ax) for ax in rest)
            y_group = first.get_shared_y_axes()
            assert all(y_group.joined(first, ax) for ax in rest)

        for first, *rest in grid.T:  # Test col-wise sharing
            x_group = first.get_shared_x_axes()
            assert all(x_group.joined(first, ax) for ax in rest)
            y_group = first.get_shared_y_axes()
            assert not any(y_group.joined(first, ax) for ax in rest)

        # Turning sharing off leaves every axes independent on both axes
        unshared_plot = spec.share(x=False, y=False).plot()
        first, *rest = unshared_plot._figure.axes
        for axis in "xy":
            group = getattr(first, f"get_shared_{axis}_axes")()
            assert not any(group.joined(first, ax) for ax in rest)

    def test_axis_sharing_with_facets(self, long_df):
        """Check axis sharing when x pairing is combined with row faceting."""
        plotted = Plot(long_df, y="y").pair(x=["a", "b"]).facet(row="c").plot()
        grid_shape = (2, 2)
        grid = np.reshape(plotted._figure.axes, grid_shape)

        # Row-wise: axes in one grid row are joined on y but not on x.
        for first, *rest in grid:
            x_group = first.get_shared_x_axes()
            assert not any(x_group.joined(first, ax) for ax in rest)
            y_group = first.get_shared_y_axes()
            assert all(y_group.joined(first, ax) for ax in rest)

        # Column-wise: axes in one grid column are joined on both x and y.
        for first, *rest in grid.T:
            x_group = first.get_shared_x_axes()
            assert all(x_group.joined(first, ax) for ax in rest)
            y_group = first.get_shared_y_axes()
            assert all(y_group.joined(first, ax) for ax in rest)

    def test_x_wrapping(self, long_df):
        """Pairing on x with `wrap` gives one labeled axes per variable."""
        x_vars = ["f", "x", "y", "z"]
        wrap = 3
        plotted = Plot(long_df, y="y").pair(x=x_vars, wrap=wrap).plot()
        axes = plotted._figure.axes

        n_rows = len(x_vars) // wrap + 1
        assert_gridspec_shape(axes[0], n_rows, wrap)
        assert len(axes) == len(x_vars)

        for ax, expected_text in zip(axes, x_vars):
            xlabel = ax.xaxis.get_label()
            assert xlabel.get_visible()
            assert xlabel.get_text() == expected_text

    def test_y_wrapping(self, long_df):
        """Pairing on y with `wrap` gives one labeled axes per variable."""
        y_vars = ["f", "x", "y", "z"]
        wrap = 3
        plotted = Plot(long_df, x="x").pair(y=y_vars, wrap=wrap).plot()
        axes = plotted._figure.axes

        n_row = wrap
        n_col = len(y_vars) // wrap + 1
        assert_gridspec_shape(axes[0], n_row, n_col)
        assert len(axes) == len(y_vars)

        # Place the names on the grid column-major, then read them back row-major
        # while dropping the cells that were never filled (left as None).
        cells = np.empty(n_row * n_col, object)
        cells[:len(y_vars)] = y_vars
        grid = cells.reshape((n_row, n_col), order="F")
        expected_texts = [name for name in grid.flat if name is not None]

        for idx, ax in enumerate(axes):
            ylabel = ax.yaxis.get_label()
            assert ylabel.get_visible()
            assert ylabel.get_text() == expected_texts[idx]

    def test_non_cross_wrapping(self, long_df):
        """Non-crossed wrapped pairing gives one axes per x/y variable pair."""
        x_vars = ["a", "b", "c", "t"]
        y_vars = ["f", "x", "y", "z"]
        wrap = 3

        spec = Plot(long_df, x="x")
        spec = spec.pair(x=x_vars, y=y_vars, wrap=wrap, cross=False)
        axes = spec.plot()._figure.axes

        n_rows = len(x_vars) // wrap + 1
        assert_gridspec_shape(axes[0], n_rows, wrap)
        assert len(axes) == len(x_vars)

    def test_cross_mismatched_lengths(self, long_df):
        """`pair(cross=False)` rejects x and y lists of different lengths."""
        spec = Plot(long_df)
        x_vars = ["a", "b"]
        y_vars = ["x", "y", "z"]
        with pytest.raises(ValueError, match="Lengths of the `x` and `y`"):
            spec.pair(x=x_vars, y=y_vars, cross=False)

    def test_orient_inference(self, long_df):
        """Record the `orient` handed to a Move for each paired y variable."""
        seen_orients = []

        class CaptureOrientMove(Move):
            """Pass-through move that remembers every `orient` it receives."""

            def __call__(self, data, groupby, orient, scales):
                seen_orients.append(orient)
                return data

        spec = Plot(long_df, x="x").pair(y=["b", "z"])
        spec = spec.add(MockMark(), CaptureOrientMove())
        spec.plot()

        assert seen_orients == ["y", "x"]

    def test_computed_coordinate_orient_inference(self, long_df):
        """The mark receives orient "y" when a Stat computes the other coordinate."""

        class MockComputeStat(Stat):
            """Stat that fills in the non-orient coordinate as twice the orient one."""

            def __call__(self, df, groupby, orient, scales):
                opposite = {"x": "y", "y": "x"}
                other = opposite[orient]
                doubled = df[orient] * 2
                return df.assign(**{other: doubled})

        mark = MockMark()
        Plot(long_df, y="y").add(mark, MockComputeStat()).plot()
        assert mark.passed_orient == "y"

    def test_two_variables_single_order_error(self, long_df):
        """A flat `order` list with both `col` and `row` raises RuntimeError."""
        spec = Plot(long_df)
        message = "When faceting on both col= and row=, passing `order`"
        flat_order = ["a", "b", "c"]
        with pytest.raises(RuntimeError, match=message):
            spec.facet(col="a", row="b", order=flat_order)

    def test_limits(self, long_df):
        """Limits given as `x` and `x1` land on the matching paired axes."""
        first_lim, second_lim = (-3, 10), (-2, 24)
        spec = Plot(long_df, y="y").pair(x=["x", "z"])
        plotted = spec.limit(x=first_lim, x1=second_lim).plot()

        for ax, expected in zip(plotted._figure.axes, (first_lim, second_lim)):
            assert ax.get_xlim() == expected

    def test_labels(self, long_df):
        """Labels given as `x` (callable) and `x1` (string) reach their axes."""
        custom = "zed"
        spec = Plot(long_df, y="y").pair(x=["x", "z"])
        spec = spec.label(x=str.capitalize, x1=custom)

        first_ax, second_ax = spec.plot()._figure.axes
        assert first_ax.get_xlabel() == "X"
        assert second_ax.get_xlabel() == custom


class TestLabelVisibility:
    """Tests for which axis labels and tick labels are visible on each subplot."""

    def has_xaxis_labels(self, ax):
        """Return whether `ax` reports x tick labels, depending on mpl version."""
        if _version_predates(mpl, "3.7"):
            # mpl3.7 added a getter for tick params; before that, fall back to
            # checking that there are any x tick labels at all.
            return len(ax.get_xticklabels()) > 0
        if _version_predates(mpl, "3.10"):
            # From 3.7 up to (not including) 3.10, the xaxis getter is read through
            # the "labelleft" entry rather than "labelbottom".
            return ax.xaxis.get_tick_params()["labelleft"]
        # From 3.10 on, the "labelbottom" entry is read.
        return ax.xaxis.get_tick_params()["labelbottom"]

    def test_single_subplot(self, long_df):
        """A lone subplot shows both axis labels and all tick labels."""
        plotted = Plot(long_df, x="a", y="z").plot()
        subplot, *_ = plotted._subplots
        ax = subplot["ax"]

        assert ax.xaxis.get_label().get_visible()
        assert ax.yaxis.get_label().get_visible()
        assert all(t.get_visible() for t in ax.get_xticklabels())
        assert all(t.get_visible() for t in ax.get_yticklabels())

    @pytest.mark.parametrize(
        "facet_kws,pair_kws", [({"col": "b"}, {}), ({}, {"x": ["x", "y", "f"]})]
    )
    def test_1d_column(self, long_df, facet_kws, pair_kws):
        """Check label visibility for the first subplot and any that follow it."""
        # NOTE(review): `facet_kws` is unused and `pair_kws` only decides the value
        # of `x`; neither is passed to `facet`/`pair` -- confirm this is intended.
        x_var = None if "x" in pair_kws else "a"
        y_var = "z"
        plotted = Plot(long_df, x=x_var, y=y_var).plot()
        head, *tail = plotted._subplots

        ax = head["ax"]
        assert ax.xaxis.get_label().get_visible()
        assert ax.yaxis.get_label().get_visible()
        assert all(t.get_visible() for t in ax.get_xticklabels())
        assert all(t.get_visible() for t in ax.get_yticklabels())

        for ax in (sub["ax"] for sub in tail):
            assert ax.xaxis.get_label().get_visible()
            assert not ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_xticklabels())
            assert not any(t.get_visible() for t in ax.get_yticklabels())

    @pytest.mark.parametrize(
        "facet_kws,pair_kws", [({"row": "b"}, {}), ({}, {"y": ["x", "y", "f"]})]
    )
    def test_1d_row(self, long_df, facet_kws, pair_kws):
        """Check label visibility for the first subplot and any that follow it."""
        # NOTE(review): `facet_kws` is unused and `pair_kws` only decides the value
        # of `y`; neither is passed to `facet`/`pair` -- confirm this is intended.
        x_var = "z"
        y_var = None if "y" in pair_kws else "z"
        plotted = Plot(long_df, x=x_var, y=y_var).plot()
        head, *tail = plotted._subplots

        ax = head["ax"]
        assert ax.xaxis.get_label().get_visible()
        assert all(t.get_visible() for t in ax.get_xticklabels())
        assert ax.yaxis.get_label().get_visible()
        assert all(t.get_visible() for t in ax.get_yticklabels())

        for ax in (sub["ax"] for sub in tail):
            assert not ax.xaxis.get_label().get_visible()
            assert ax.yaxis.get_label().get_visible()
            assert not any(t.get_visible() for t in ax.get_xticklabels())
            assert all(t.get_visible() for t in ax.get_yticklabels())

    def test_1d_column_wrapped(self):
        """Check label visibility for four column facets wrapped at three."""
        plotted = Plot().facet(col=["a", "b", "c", "d"], wrap=3).plot()
        axes = [sub["ax"] for sub in plotted._subplots]

        # First and last subplots carry the y label and y tick labels.
        for ax in (axes[0], axes[-1]):
            assert ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_yticklabels())

        # Every subplot but the first carries the x label and x tick labels.
        for ax in axes[1:]:
            assert ax.xaxis.get_label().get_visible()
            assert self.has_xaxis_labels(ax)
            assert all(t.get_visible() for t in ax.get_xticklabels())

        # The middle subplots hide the y label and y tick labels.
        for ax in axes[1:-1]:
            assert not ax.yaxis.get_label().get_visible()
            assert not any(t.get_visible() for t in ax.get_yticklabels())

        # The first subplot hides the x label and x tick labels.
        ax = axes[0]
        assert not ax.xaxis.get_label().get_visible()
        assert not any(t.get_visible() for t in ax.get_xticklabels())

    def test_1d_row_wrapped(self):
        """Check label visibility for four row facets wrapped at three."""
        plotted = Plot().facet(row=["a", "b", "c", "d"], wrap=3).plot()
        axes = [sub["ax"] for sub in plotted._subplots]

        # Every subplot but the last carries the y label and y tick labels.
        for ax in axes[:-1]:
            assert ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_yticklabels())

        # The last two subplots carry the x label and x tick labels.
        for ax in axes[-2:]:
            assert ax.xaxis.get_label().get_visible()
            assert self.has_xaxis_labels(ax)
            assert all(t.get_visible() for t in ax.get_xticklabels())

        # The earlier subplots hide the x label and x tick labels.
        for ax in axes[:-2]:
            assert not ax.xaxis.get_label().get_visible()
            assert not any(t.get_visible() for t in ax.get_xticklabels())

        # The last subplot hides the y label and y tick labels.
        ax = axes[-1]
        assert not ax.yaxis.get_label().get_visible()
        assert not any(t.get_visible() for t in ax.get_yticklabels())

    def test_1d_column_wrapped_non_cross(self, long_df):
        """With non-crossed wrapped pairing, every subplot shows all labels."""
        spec = Plot(long_df)
        spec = spec.pair(x=["a", "b", "c"], y=["x", "y", "z"], wrap=2, cross=False)
        plotted = spec.plot()

        for ax in (sub["ax"] for sub in plotted._subplots):
            assert ax.xaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_xticklabels())
            assert ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_yticklabels())

    def test_2d(self):
        """Check label visibility on a 2x2 facet grid with default sharing."""
        plotted = Plot().facet(col=["a", "b"], row=["x", "y"]).plot()
        axes = [sub["ax"] for sub in plotted._subplots]

        # First two subplots: x label and x tick labels hidden.
        for ax in axes[:2]:
            assert not ax.xaxis.get_label().get_visible()
            assert not any(t.get_visible() for t in ax.get_xticklabels())

        # Last two subplots: x label and x tick labels shown.
        for ax in axes[2:]:
            assert ax.xaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_xticklabels())

        # Subplots 0 and 2: y label and y tick labels shown.
        for ax in (axes[0], axes[2]):
            assert ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_yticklabels())

        # Subplots 1 and 3: y label and y tick labels hidden.
        for ax in (axes[1], axes[3]):
            assert not ax.yaxis.get_label().get_visible()
            assert not any(t.get_visible() for t in ax.get_yticklabels())

    def test_2d_unshared(self):
        """Check label visibility on a 2x2 facet grid with sharing disabled."""
        spec = Plot().facet(col=["a", "b"], row=["x", "y"])
        plotted = spec.share(x=False, y=False).plot()
        axes = [sub["ax"] for sub in plotted._subplots]

        # First two subplots: x label hidden, but x tick labels still shown.
        for ax in axes[:2]:
            assert not ax.xaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_xticklabels())

        # Last two subplots: x label and x tick labels shown.
        for ax in axes[2:]:
            assert ax.xaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_xticklabels())

        # Subplots 0 and 2: y label and y tick labels shown.
        for ax in (axes[0], axes[2]):
            assert ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_yticklabels())

        # Subplots 1 and 3: y label hidden, but y tick labels still shown.
        for ax in (axes[1], axes[3]):
            assert not ax.yaxis.get_label().get_visible()
            assert all(t.get_visible() for t in ax.get_yticklabels())


class TestLegend:
    """Tests for the legend contents and legend artists produced by `Plot`."""

    @pytest.fixture
    def xy(self):
        """Minimal x/y coordinate data shared by the legend tests."""
        return {"x": [1, 2, 3, 4], "y": [1, 2, 3, 4]}

    def _check_entry(self, entry, key, labels, variables):
        """Assert one `_legend_contents` entry matches the expected pieces.

        `entry[0]` is compared with `key`, `entry[-1]` with `labels`, and each
        artist in `entry[1]` must carry the matching label and `variables`.
        """
        assert entry[0] == key
        assert entry[-1] == labels

        artists = entry[1]
        assert len(artists) == len(labels)
        for artist, label in zip(artists, labels):
            assert isinstance(artist, mpl.artist.Artist)
            assert artist.value == label
            assert artist.variables == variables

    def test_single_layer_single_variable(self, xy):
        """One layer with one semantic variable yields a single legend entry."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        plotted = Plot(**xy).add(MockMark(), color=s).plot()
        (entry,) = plotted._legend_contents

        labels = categorical_order(s)
        self._check_entry(entry, (s.name, s.name), labels, ["color"])

    def test_single_layer_common_variable(self, xy):
        """Two properties mapped from the same series share one legend entry."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        sem = {"color": s, "marker": s}
        plotted = Plot(**xy).add(MockMark(), **sem).plot()
        (entry,) = plotted._legend_contents

        labels = categorical_order(s)
        self._check_entry(entry, (s.name, s.name), labels, list(sem))

    def test_single_layer_common_unnamed_variable(self, xy):
        """An unnamed array mapped twice is keyed by an empty title and its id."""
        s = np.array(["a", "b", "a", "c"])
        sem = {"color": s, "marker": s}
        plotted = Plot(**xy).add(MockMark(), **sem).plot()
        (entry,) = plotted._legend_contents

        labels = list(np.unique(s))  # assumes sorted order
        self._check_entry(entry, ("", id(s)), labels, list(sem))

    def test_single_layer_multi_variable(self, xy):
        """Two differently named variables in one layer yield two entries."""
        s1 = pd.Series(["a", "b", "a", "c"], name="s1")
        s2 = pd.Series(["m", "m", "p", "m"], name="s2")
        sem = {"color": s1, "marker": s2}
        plotted = Plot(**xy).add(MockMark(), **sem).plot()
        first, second = plotted._legend_contents

        # Map each series name back to the property it was assigned to.
        prop_by_name = {series.name: prop for prop, series in sem.items()}

        for entry, series in zip([first, second], [s1, s2]):
            key = (series.name, series.name)
            labels = categorical_order(series)
            self._check_entry(entry, key, labels, [prop_by_name[series.name]])

    def test_multi_layer_single_variable(self, xy):
        """A plot-level variable drawn by two layers yields an entry per layer."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        plotted = Plot(**xy, color=s).add(MockMark()).add(MockMark()).plot()
        first, second = plotted._legend_contents

        labels = categorical_order(s)
        for entry in (first, second):
            self._check_entry(entry, (s.name, s.name), labels, ["color"])

    def test_multi_layer_multi_variable(self, xy):
        """Layers with their own variables each contribute a legend entry."""
        s1 = pd.Series(["a", "b", "a", "c"], name="s1")
        s2 = pd.Series(["m", "m", "p", "m"], name="s2")
        layer_sems = {"color": s1}, {"marker": s2}
        prop_by_name = {"s1": "color", "s2": "marker"}

        spec = Plot(**xy)
        spec = spec.add(MockMark(), **layer_sems[0])
        spec = spec.add(MockMark(), **layer_sems[1])
        first, second = spec.plot()._legend_contents

        for entry, series in zip([first, second], [s1, s2]):
            key = (series.name, series.name)
            labels = categorical_order(series)
            self._check_entry(entry, key, labels, [prop_by_name[series.name]])

    def test_multi_layer_different_artists(self, xy):
        """Layers with different legend artists are combined in one legend."""

        class MockMark1(MockMark):
            def _legend_artist(self, variables, value, scales):
                return mpl.lines.Line2D([], [])

        class MockMark2(MockMark):
            def _legend_artist(self, variables, value, scales):
                return mpl.patches.Patch()

        s = pd.Series(["a", "b", "a", "c"], name="s")
        plotted = Plot(**xy, color=s).add(MockMark1()).add(MockMark2()).plot()

        (legend,) = plotted._figure.legends

        names = categorical_order(s)
        texts = [t.get_text() for t in legend.get_texts()]
        assert texts == names

        # The artist counts are only checked on matplotlib 3.5 and later.
        if not _version_predates(mpl, "3.5"):
            contents = legend.get_children()[0]
            assert len(contents.findobj(mpl.lines.Line2D)) == len(names)
            assert len(contents.findobj(mpl.patches.Patch)) == len(names)

    def test_three_layers(self, xy):
        """Three layers sharing a variable give one label per unique value."""

        class MockMarkLine(MockMark):
            def _legend_artist(self, variables, value, scales):
                return mpl.lines.Line2D([], [])

        s = pd.Series(["a", "b", "a", "c"], name="s")
        spec = Plot(**xy, color=s)
        for _ in range(3):
            spec = spec.add(MockMarkLine())
        plotted = spec.plot()

        legend_texts = plotted._figure.legends[0].get_texts()
        assert len(legend_texts) == len(s.unique())

    def test_identity_scale_ignored(self, xy):
        """A variable scaled with `None` contributes no legend contents."""
        s = pd.Series(["r", "g", "b", "g"])
        spec = Plot(**xy).add(MockMark(), color=s).scale(color=None)
        assert not spec.plot()._legend_contents

    def test_suppression_in_add_method(self, xy):
        """`legend=False` in `add` leaves the legend contents empty."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        spec = Plot(**xy).add(MockMark(), color=s, legend=False)
        assert not spec.plot()._legend_contents

    def test_anonymous_title(self, xy):
        """A variable given as a plain list gets an empty legend title."""
        spec = Plot(**xy, color=["a", "b", "c", "d"]).add(MockMark())
        (legend,) = spec.plot()._figure.legends
        assert legend.get_title().get_text() == ""

    def test_legendless_mark(self, xy):
        """A mark whose legend artist is None produces no figure legend."""

        class NoLegendMark(MockMark):
            def _legend_artist(self, variables, value, scales):
                return None

        spec = Plot(**xy, color=["a", "b", "c", "d"]).add(NoLegendMark())
        assert not spec.plot()._figure.legends

    def test_legend_has_no_offset(self, xy):
        """Legend labels for large values parse to numbers above 1e7."""
        color = np.add(xy["x"], 1e8)
        plotted = Plot(**xy, color=color).add(MockMark()).plot()
        legend = plotted._figure.legends[0]

        assert legend.texts
        for text in legend.texts:
            assert float(text.get_text()) > 1e7

    def test_layer_legend(self, xy):
        """Layer `label` values appear as legend texts in layer order."""
        spec = Plot(**xy).add(MockMark(), label="a").add(MockMark(), label="b")
        legend = spec.plot()._figure.legends[0]

        assert legend.texts
        for text, expected in zip(legend.texts, "ab"):
            assert text.get_text() == expected

    def test_layer_legend_with_scale_legend(self, xy):
        """A layer label and the scale's values all appear in the legend."""
        s = pd.Series(["a", "b", "a", "c"], name="s")
        plotted = Plot(**xy, color=s).add(MockMark(), label="x").plot()

        legend = plotted._figure.legends[0]
        found = [t.get_text() for t in legend.findobj(mpl.text.Text)]
        assert "x" in found
        for value in s.unique():
            assert value in found

    def test_layer_legend_title(self, xy):
        """`label(legend=...)` sets the title of the layer legend."""
        spec = Plot(**xy).add(MockMark(), label="x").label(legend="layer")
        legend = spec.plot()._figure.legends[0]
        assert legend.get_title().get_text() == "layer"


class TestDefaultObject:
    """Tests for the `Default` sentinel object."""

    def test_default_repr(self):
        """The repr of a `Default` instance is the empty string."""
        sentinel = Default()
        assert repr(sentinel) == ""


class TestThemeConfig:
    """Tests for reading, changing, and resetting `Plot.config.theme`."""

    @pytest.fixture(autouse=True)
    def reset_config(self):
        """Reset the theme after every test so changes do not leak between tests."""
        yield
        Plot.config.theme.reset()

    def test_default(self):
        """A plain plot uses the theme's "axes.facecolor"."""
        p = Plot().plot()
        ax = p._figure.axes[0]
        expected = Plot.config.theme["axes.facecolor"]
        assert mpl.colors.same_color(ax.get_facecolor(), expected)

    def test_setitem(self):
        """Assigning a theme key changes the facecolor of later plots."""
        color = "#CCC"
        Plot.config.theme["axes.facecolor"] = color
        p = Plot().plot()
        ax = p._figure.axes[0]
        assert mpl.colors.same_color(ax.get_facecolor(), color)

    def test_update(self):
        """`theme.update` changes the facecolor of later plots."""
        color = "#DDD"
        Plot.config.theme.update({"axes.facecolor": color})
        p = Plot().plot()
        ax = p._figure.axes[0]
        assert mpl.colors.same_color(ax.get_facecolor(), color)

    def test_reset(self):
        """`theme.reset` restores the facecolor that was in place before an update."""
        orig = Plot.config.theme["axes.facecolor"]
        Plot.config.theme.update({"axes.facecolor": "#EEE"})
        Plot.config.theme.reset()
        p = Plot().plot()
        ax = p._figure.axes[0]
        assert mpl.colors.same_color(ax.get_facecolor(), orig)

    def test_copy(self):
        """Updating a copy of the theme leaves the global theme untouched."""
        key, val = "axes.facecolor", ".95"
        orig = Plot.config.theme[key]
        theme = Plot.config.theme.copy()
        theme.update({key: val})
        assert Plot.config.theme[key] == orig

    def test_html_repr(self):
        """The HTML repr has balanced tags and a labeled cell for each theme key."""
        res = Plot.config.theme._repr_html_()

        # Every opening tag of these kinds must have a matching closing tag.
        for tag in ["div", "table", "tr", "td"]:
            assert res.count(f"<{tag}") == res.count(f"</{tag}")

        # Each theme key is expected in its own table cell, followed by a colon.
        for key in Plot.config.theme:
            assert f"<td>{key}:</td>" in res


class TestDisplayConfig:
    """Tests for the `Plot.config.display` settings used by the rich reprs."""

    @pytest.fixture(autouse=True)
    def reset_config(self):
        """Restore the display settings from a fresh `PlotConfig` after each test."""
        yield
        Plot.config.display.update(PlotConfig().display)

    def test_png_format(self):
        """With format "png", the SVG repr is None and the PNG repr is valid."""
        Plot.config.display["format"] = "png"

        assert Plot()._repr_svg_() is None
        assert Plot().plot()._repr_svg_() is None

        def check_png(obj):
            payload, metadata = obj._repr_png_()
            image = Image.open(io.BytesIO(payload))
            assert image.format == "PNG"
            assert sorted(metadata) == ["height", "width"]

        check_png(Plot())
        check_png(Plot().plot())

    def test_svg_format(self):
        """With format "svg", the PNG repr is None and the SVG repr parses."""
        Plot.config.display["format"] = "svg"

        assert Plot()._repr_png_() is None
        assert Plot().plot()._repr_png_() is None

        def check_svg(obj):
            markup = obj._repr_svg_()
            root = xml.etree.ElementTree.fromstring(markup)
            assert root.tag == "{http://www.w3.org/2000/svg}svg"

        check_svg(Plot())
        check_svg(Plot().plot())

    def test_png_scaling(self):
        """Halving "scaling" halves the PNG metadata size but not the pixel size."""
        Plot.config.display["scaling"] = 1.
        full_data, full_meta = Plot()._repr_png_()

        Plot.config.display["scaling"] = .5
        half_data, half_meta = Plot()._repr_png_()

        assert full_meta["width"] / 2 == half_meta["width"]
        assert full_meta["height"] / 2 == half_meta["height"]

        full_image = Image.open(io.BytesIO(full_data))
        half_image = Image.open(io.BytesIO(half_data))
        assert full_image.size == half_image.size

    def test_svg_scaling(self):
        """Halving "scaling" halves the width and height attributes of the SVG."""
        Plot.config.display["format"] = "svg"

        Plot.config.display["scaling"] = 1.
        full_markup = Plot()._repr_svg_()

        Plot.config.display["scaling"] = .5
        half_markup = Plot()._repr_svg_()

        full_root = xml.etree.ElementTree.fromstring(full_markup)
        half_root = xml.etree.ElementTree.fromstring(half_markup)

        def getdim(root, dim):
            # Drop the last two characters of the attribute before converting;
            # presumably a unit suffix such as "pt" -- TODO confirm.
            raw = root.attrib[dim]
            return float(raw[:-2])

        assert getdim(full_root, "width") / 2 == getdim(half_root, "width")
        assert getdim(full_root, "height") / 2 == getdim(half_root, "height")

    def test_png_hidpi(self):
        """Turning "hidpi" off keeps the metadata size but halves the pixel size."""
        hi_data, hi_meta = Plot()._repr_png_()

        Plot.config.display["hidpi"] = False
        lo_data, lo_meta = Plot()._repr_png_()

        assert hi_meta["width"] == lo_meta["width"]
        assert hi_meta["height"] == lo_meta["height"]

        hi_image = Image.open(io.BytesIO(hi_data))
        lo_image = Image.open(io.BytesIO(lo_data))
        assert hi_image.size[0] // 2 == lo_image.size[0]
        assert hi_image.size[1] // 2 == lo_image.size[1]


================================================
FILE: tests/_core/test_properties.py
================================================

import numpy as np
import pandas as pd
import matplotlib as mpl
from matplotlib.colors import same_color, to_rgb, to_rgba
from matplotlib.markers import MarkerStyle

import pytest
from numpy.testing import assert_array_equal

from seaborn._core.rules import categorical_order
from seaborn._core.scales import Nominal, Continuous, Boolean
from seaborn._core.properties import (
    Alpha,
    Color,
    Coordinate,
    EdgeWidth,
    Fill,
    LineStyle,
    LineWidth,
    Marker,
    PointSize,
)
from seaborn._compat import get_colormap
from seaborn.palettes import color_palette


class DataFixtures:
    """Mixin of fixtures exposing columns of `long_df` for the property tests."""

    @pytest.fixture
    def num_vector(self, long_df):
        """The "s" column of `long_df`."""
        column = long_df["s"]
        return column

    @pytest.fixture
    def num_order(self, num_vector):
        """Result of `categorical_order` applied to `num_vector`."""
        order = categorical_order(num_vector)
        return order

    @pytest.fixture
    def cat_vector(self, long_df):
        """The "a" column of `long_df`."""
        column = long_df["a"]
        return column

    @pytest.fixture
    def cat_order(self, cat_vector):
        """Result of `categorical_order` applied to `cat_vector`."""
        order = categorical_order(cat_vector)
        return order

    @pytest.fixture
    def dt_num_vector(self, long_df):
        """The "t" column of `long_df`."""
        column = long_df["t"]
        return column

    @pytest.fixture
    def dt_cat_vector(self, long_df):
        """The "d" column of `long_df`."""
        column = long_df["d"]
        return column

    @pytest.fixture
    def bool_vector(self, long_df):
        """Elementwise result of comparing the "x" column of `long_df` with 10."""
        column = long_df["x"]
        return column > 10

    @pytest.fixture
    def vectors(self, num_vector, cat_vector, bool_vector):
        """The three example vectors keyed by "num", "cat", and "bool"."""
        lookup = {}
        lookup["num"] = num_vector
        lookup["cat"] = cat_vector
        lookup["bool"] = bool_vector
        return lookup


class TestCoordinate(DataFixtures):
    """Tests for the errors `Coordinate.infer_scale` raises on bad arguments."""

    def test_bad_scale_arg_str(self, num_vector):
        """An unrecognized string scale argument raises ValueError."""
        prop = Coordinate("x")
        message = "Unknown magic arg for x scale: 'xxx'."
        with pytest.raises(ValueError, match=message):
            prop.infer_scale("xxx", num_vector)

    def test_bad_scale_arg_type(self, cat_vector):
        """A list scale argument raises TypeError."""
        prop = Coordinate("x")
        message = "Magic arg for x scale must be str, not list."
        with pytest.raises(TypeError, match=message):
            prop.infer_scale([1, 2, 3], cat_vector)


class TestColor(DataFixtures):
    """Tests for mappings, scale inference, and standardization of `Color`."""

    def assert_same_rgb(self, a, b):
        """Assert that the first three columns of `a` and `b` are equal."""
        rgb_a = a[:, :3]
        rgb_b = b[:, :3]
        assert_array_equal(rgb_a, rgb_b)

    def _assert_colors_match(self, actual, expected):
        """Assert pairwise `same_color` over the zipped `actual`/`expected`."""
        for have, want in zip(actual, expected):
            assert same_color(have, want)

    def test_nominal_default_palette(self, cat_vector, cat_order):
        """A bare Nominal scale maps to `color_palette(None, n)`."""
        mapping = Color().get_mapping(Nominal(), cat_vector)
        n_levels = len(cat_order)
        actual = mapping(np.arange(n_levels))
        expected = color_palette(None, n_levels)
        self._assert_colors_match(actual, expected)

    def test_nominal_default_palette_large(self):
        """With 26 levels, a bare Nominal scale maps to the "husl" palette."""
        letters = pd.Series(list("abcdefghijklmnopqrstuvwxyz"))
        mapping = Color().get_mapping(Nominal(), letters)
        actual = mapping(np.arange(26))
        expected = color_palette("husl", 26)
        self._assert_colors_match(actual, expected)

    def test_nominal_named_palette(self, cat_vector, cat_order):
        """A Nominal scale with a palette name maps to that palette."""
        palette = "Blues"
        mapping = Color().get_mapping(Nominal(palette), cat_vector)
        n_levels = len(cat_order)
        actual = mapping(np.arange(n_levels))
        expected = color_palette(palette, n_levels)
        self._assert_colors_match(actual, expected)

    def test_nominal_list_palette(self, cat_vector, cat_order):
        """A Nominal scale with a list of colors maps to that list."""
        palette = color_palette("Reds", len(cat_order))
        mapping = Color().get_mapping(Nominal(palette), cat_vector)
        actual = mapping(np.arange(len(palette)))
        self._assert_colors_match(actual, palette)

    def test_nominal_dict_palette(self, cat_vector, cat_order):
        """A Nominal scale with a level-to-color dict maps to those colors."""
        colors = color_palette("Greens")
        palette = dict(zip(cat_order, colors))
        mapping = Color().get_mapping(Nominal(palette), cat_vector)
        n_levels = len(cat_order)
        actual = mapping(np.arange(n_levels))
        self._assert_colors_match(actual, colors)

    def test_nominal_dict_with_missing_keys(self, cat_vector, cat_order):
        """A color dict lacking the first level raises ValueError."""
        partial = dict(zip(cat_order[1:], color_palette("Purples")))
        scale = Nominal(partial)
        with pytest.raises(ValueError, match="No entry in color dict"):
            Color("color").get_mapping(scale, cat_vector)

    def test_nominal_list_too_short(self, cat_vector, cat_order):
        """A color list one shorter than the levels triggers a UserWarning."""
        n = len(cat_order) - 1
        palette = color_palette("Oranges", n)
        msg = rf"The edgecolor list has fewer values \({n}\) than needed \({n + 1}\)"
        scale = Nominal(palette)
        with pytest.warns(UserWarning, match=msg):
            Color("edgecolor").get_mapping(scale, cat_vector)

    def test_nominal_list_too_long(self, cat_vector, cat_order):
        """A color list one longer than the levels triggers a UserWarning."""
        n = len(cat_order) + 1
        palette = color_palette("Oranges", n)
        msg = rf"The edgecolor list has more values \({n}\) than needed \({n - 1}\)"
        scale = Nominal(palette)
        with pytest.warns(UserWarning, match=msg):
            Color("edgecolor").get_mapping(scale, cat_vector)

    def test_continuous_default_palette(self, num_vector):
        """A bare Continuous scale matches the "ch:" colormap."""
        cmap = color_palette("ch:", as_cmap=True)
        mapping = Color().get_mapping(Continuous(), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_named_palette(self, num_vector):
        """A Continuous scale with a palette name matches that colormap."""
        palette_name = "flare"
        cmap = color_palette(palette_name, as_cmap=True)
        mapping = Color().get_mapping(Continuous(palette_name), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_tuple_palette(self, num_vector):
        """A Continuous scale with a color tuple matches the "blend:" colormap."""
        endpoints = ("blue", "red")
        cmap = color_palette("blend:" + ",".join(endpoints), as_cmap=True)
        mapping = Color().get_mapping(Continuous(endpoints), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_callable_palette(self, num_vector):
        """A Continuous scale given a colormap object matches that colormap."""
        cmap = get_colormap("viridis")
        mapping = Color().get_mapping(Continuous(cmap), num_vector)
        self.assert_same_rgb(mapping(num_vector), cmap(num_vector))

    def test_continuous_missing(self):
        """A NaN input maps to an all-NaN output row."""
        values = pd.Series([1, 2, np.nan, 4])
        mapping = Color().get_mapping(Continuous(), values)
        mapped = mapping(values)
        assert np.isnan(mapped[2]).all()

    def test_bad_scale_values_continuous(self, num_vector):
        """A list of colors on a Continuous scale raises TypeError."""
        scale = Continuous(["r", "g", "b"])
        with pytest.raises(TypeError, match="Scale values for color with a Continuous"):
            Color().get_mapping(scale, num_vector)

    def test_bad_scale_values_nominal(self, cat_vector):
        """A colormap object on a Nominal scale raises TypeError."""
        scale = Nominal(get_colormap("viridis"))
        with pytest.raises(TypeError, match="Scale values for color with a Nominal"):
            Color().get_mapping(scale, cat_vector)

    def test_bad_inference_arg(self, cat_vector):
        """An integer scale argument raises TypeError during inference."""
        prop = Color()
        with pytest.raises(TypeError, match="A single scale argument for color"):
            prop.infer_scale(123, cat_vector)

    @pytest.mark.parametrize(
        "data_type,scale_class",
        [("cat", Nominal), ("num", Continuous), ("bool", Boolean)]
    )
    def test_default(self, data_type, scale_class, vectors):
        """The default scale class follows the type of the data vector."""
        data = vectors[data_type]
        scale = Color().default_scale(data)
        assert isinstance(scale, scale_class)

    def test_default_numeric_data_category_dtype(self, num_vector):
        """Numeric data cast to "category" defaults to a Nominal scale."""
        as_category = num_vector.astype("category")
        scale = Color().default_scale(as_category)
        assert isinstance(scale, Nominal)

    def test_default_binary_data(self):
        """Integer 0/1 data defaults to a Continuous scale."""
        binary = pd.Series([0, 0, 1, 0, 1], dtype=int)
        scale = Color().default_scale(binary)
        assert isinstance(scale, Continuous)

    @pytest.mark.parametrize(
        "values,data_type,scale_class",
        [
            ("viridis", "cat", Nominal),  # Based on variable type
            ("viridis", "num", Continuous),  # Based on variable type
            ("viridis", "bool", Boolean),  # Based on variable type
            ("muted", "num", Nominal),  # Based on qualitative palette
            (["r", "g", "b"], "num", Nominal),  # Based on list palette
            ({2: "r", 4: "g", 8: "b"}, "num", Nominal),  # Based on dict palette
            (("r", "b"), "num", Continuous),  # Based on tuple / variable type
            (("g", "m"), "cat", Nominal),  # Based on tuple / variable type
            (("c", "y"), "bool", Boolean),  # Based on tuple / variable type
            (get_colormap("inferno"), "num", Continuous),  # Based on callable
        ]
    )
    def test_inference(self, values, data_type, scale_class, vectors):
        """The inferred scale has the expected class and keeps `values`."""
        data = vectors[data_type]
        scale = Color().infer_scale(values, data)
        assert isinstance(scale, scale_class)
        assert scale.values == values

    def test_standardization(self):
        """`standardize` agrees with matplotlib's `to_rgb` / `to_rgba`."""
        standardize = Color().standardize

        # Named colors and color-cycle references.
        assert standardize("C3") == to_rgb("C3")
        assert standardize("dodgerblue") == to_rgb("dodgerblue")

        # RGB and RGBA tuples come back unchanged.
        assert standardize((.1, .2, .3)) == (.1, .2, .3)
        assert standardize((.1, .2, .3, .4)) == (.1, .2, .3, .4)

        # Long hex strings, without and with an alpha component.
        assert standardize("#123456") == to_rgb("#123456")
        assert standardize("#12345678") == to_rgba("#12345678")

        # Short hex strings, without and with an alpha component.
        assert standardize("#123") == to_rgb("#123")
        assert standardize("#1234") == to_rgba("#1234")


class ObjectPropertyBase(DataFixtures):
    """Shared tests for properties whose mapped values are arbitrary objects.

    The tests read ``self.prop`` (the property class), ``self.values`` (raw
    scale values), and ``self.standardized_values`` (what those values should
    map to); subclasses are expected to define them. ``assert_equal`` and
    ``unpack`` can be overridden when plain ``==`` is not a usable comparison.
    """

    def assert_equal(self, a, b):
        """Assert that two mapped values are equal once unpacked."""
        unpacked_a = self.unpack(a)
        unpacked_b = self.unpack(b)
        assert unpacked_a == unpacked_b

    def unpack(self, x):
        """Return a comparable form of ``x``; the base version is the identity."""
        return x

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_default(self, data_type, vectors):
        """Default scale is Boolean for boolean data and Nominal otherwise."""
        expected_class = Boolean if data_type == "bool" else Nominal
        scale = self.prop().default_scale(vectors[data_type])
        assert isinstance(scale, expected_class)

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_list(self, data_type, vectors):
        """A list of values yields a Boolean/Nominal scale that keeps the list."""
        expected_class = Boolean if data_type == "bool" else Nominal
        scale = self.prop().infer_scale(self.values, vectors[data_type])
        assert isinstance(scale, expected_class)
        assert scale.values == self.values

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_dict(self, data_type, vectors):
        """A dict of values yields a Boolean/Nominal scale that keeps the dict."""
        data = vectors[data_type]
        by_level = dict(zip(categorical_order(data), self.values))
        expected_class = Boolean if data_type == "bool" else Nominal
        scale = self.prop().infer_scale(by_level, data)
        assert isinstance(scale, expected_class)
        assert scale.values == by_level

    def test_dict_missing(self, cat_vector):
        """A values dict that lacks a level raises a ValueError naming it."""
        levels = categorical_order(cat_vector)
        # Dropping the final value is meant to leave the last level unmapped
        partial = dict(zip(levels, self.values[:-1]))
        scale = Nominal(partial)
        name = self.prop.__name__.lower()
        msg = f"No entry in {name} dictionary for {repr(levels[-1])}"
        with pytest.raises(ValueError, match=msg):
            self.prop().get_mapping(scale, cat_vector)

    @pytest.mark.parametrize("data_type", ["cat", "num"])
    def test_mapping_default(self, data_type, vectors):
        """Without scale values, each index maps to the property's default."""
        data = vectors[data_type]
        mapping = self.prop().get_mapping(Nominal(), data)
        defaults = self.prop()._default_values(data.nunique())
        for index, expected in enumerate(defaults):
            [actual] = mapping([index])
            self.assert_equal(actual, expected)

    @pytest.mark.parametrize("data_type", ["cat", "num"])
    def test_mapping_from_list(self, data_type, vectors):
        """List values map by position to their standardized counterparts."""
        data = vectors[data_type]
        mapping = self.prop().get_mapping(Nominal(self.values), data)
        for index, expected in enumerate(self.standardized_values):
            [actual] = mapping([index])
            self.assert_equal(actual, expected)

    @pytest.mark.parametrize("data_type", ["cat", "num"])
    def test_mapping_from_dict(self, data_type, vectors):
        """Dict values map each level to the entry stored under that level."""
        data = vectors[data_type]
        levels = categorical_order(data)
        # Reversed so the level -> value pairing differs from the list tests
        raw_by_level = dict(zip(levels, self.values[::-1]))
        expected_by_level = dict(zip(levels, self.standardized_values[::-1]))

        mapping = self.prop().get_mapping(Nominal(raw_by_level), data)
        for index, level in enumerate(levels):
            [actual] = mapping([index])
            self.assert_equal(actual, expected_by_level[level])

    def test_mapping_with_null_value(self, cat_vector):
        """A NaN index maps to the property's ``null_value``."""
        scale = Nominal(self.values)
        mapping = self.prop().get_mapping(scale, cat_vector)
        mapped = mapping(np.array([0, np.nan, 2]))
        first, _, third = self.standardized_values
        expected = [first, self.prop.null_value, third]
        for got, want in zip(mapped, expected):
            self.assert_equal(got, want)

    def test_unique_default_large_n(self):
        """Default values stay distinct for 24 levels."""
        n_levels = 24
        data = pd.Series(np.arange(n_levels))
        mapping = self.prop().get_mapping(Nominal(), data)
        distinct = set(map(self.unpack, mapping(data)))
        assert len(distinct) == n_levels

    def test_bad_scale_values(self, cat_vector):
        """A tuple of scale values is rejected with a TypeError."""
        var_name = self.prop.__name__.lower()
        pattern = f"Scale values for a {var_name} variable"
        with pytest.raises(TypeError, match=pattern):
            self.prop().get_mapping(Nominal(("o", "s")), cat_vector)


class TestMarker(ObjectPropertyBase):
    """Run the shared object-property tests against the Marker property."""

    prop = Marker
    values = ["o", (5, 2, 0), MarkerStyle("^")]
    standardized_values = list(map(MarkerStyle, values))

    def assert_equal(self, a, b):
        """Compare two markers by path, join style, transform, and fill style."""
        path_a = a.get_path()
        path_b = b.get_path()

        # Path geometry and simplification settings
        assert_array_equal(path_a.vertices, path_b.vertices)
        assert_array_equal(path_a.codes, path_b.codes)
        assert path_a.simplify_threshold == path_b.simplify_threshold
        assert path_a.should_simplify == path_b.should_simplify

        # Remaining style attributes
        assert a.get_joinstyle() == b.get_joinstyle()
        assert a.get_transform().to_values() == b.get_transform().to_values()
        assert a.get_fillstyle() == b.get_fillstyle()

    def unpack(self, x):
        """Summarize a marker as a (path, joinstyle, transform, fillstyle) tuple."""
        path = x.get_path()
        joinstyle = x.get_joinstyle()
        transform_values = x.get_transform().to_values()
        fillstyle = x.get_fillstyle()
        return (path, joinstyle, transform_values, fillstyle)


class TestLineStyle(ObjectPropertyBase):
    """Run the shared object-property tests against LineStyle, plus input checks."""

    prop = LineStyle
    values = ["solid", "--", (1, .5)]
    standardized_values = list(map(LineStyle._get_dash_pattern, values))

    def test_bad_type(self):
        """A list is rejected with a TypeError that names the offending type."""
        prop = LineStyle()
        pattern = "^Linestyle must be .+, not list.$"
        with pytest.raises(TypeError, match=pattern):
            prop.standardize([1, 2])

    def test_bad_style(self):
        """An unrecognized style string is rejected with a ValueError."""
        prop = LineStyle()
        pattern = "^Linestyle string must be .+, not 'o'.$"
        with pytest.raises(ValueError, match=pattern):
            prop.standardize("o")

    def test_bad_dashes(self):
        """A dash tuple containing a non-number is rejected with a TypeError."""
        prop = LineStyle()
        pattern = "^Invalid dash pattern"
        with pytest.raises(TypeError, match=pattern):
            prop.standardize((1, 2, "x"))


class TestFill(DataFixtures):
    """Tests for scale inference and value mapping with the Fill property."""

    @pytest.fixture
    def vectors(self):
        """Small categorical, numeric, and boolean series keyed by type name."""
        cat = pd.Series(["a", "a", "b"])
        num = pd.Series([1, 1, 2])
        boolean = pd.Series([True, True, False])
        return {"cat": cat, "num": num, "bool": boolean}

    @pytest.fixture
    def cat_vector(self, vectors):
        """The categorical entry of ``vectors``."""
        return vectors["cat"]

    @pytest.fixture
    def num_vector(self, vectors):
        """The numeric entry of ``vectors``."""
        return vectors["num"]

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_default(self, data_type, vectors):
        """Default scale is Boolean for boolean data and Nominal otherwise."""
        expected_class = Boolean if data_type == "bool" else Nominal
        scale = Fill().default_scale(vectors[data_type])
        assert isinstance(scale, expected_class)

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_list(self, data_type, vectors):
        """A list of booleans yields a Boolean/Nominal scale that keeps the list."""
        expected_class = Boolean if data_type == "bool" else Nominal
        scale = Fill().infer_scale([True, False], vectors[data_type])
        assert isinstance(scale, expected_class)
        assert scale.values == [True, False]

    @pytest.mark.parametrize("data_type", ["cat", "num", "bool"])
    def test_inference_dict(self, data_type, vectors):
        """A dict of booleans yields a Boolean/Nominal scale that keeps the dict."""
        data = vectors[data_type]
        by_level = dict(zip(data.unique(), [True, False]))
        expected_class = Boolean if data_type == "bool" else Nominal
        scale = Fill().infer_scale(by_level, data)
        assert isinstance(scale, expected_class)
        assert scale.values == by_level

    def test_mapping_categorical_data(self, cat_vector):
        """Default mapping on categorical data: index 0 -> True, index 1 -> False."""
        mapping = Fill().get_mapping(Nominal(), cat_vector)
        mapped = mapping([0, 1, 0])
        assert_array_equal(mapped, [True, False, True])

    def test_mapping_numeric_data(self, num_vector):
        """Default mapping on numeric data: index 0 -> True, index 1 -> False."""
        mapping = Fill().get_mapping(Nominal(), num_vector)
        mapped = mapping([0, 1, 0])
        assert_array_equal(mapped, [True, False, True])

    def test_mapping_list(self, cat_vector):
        """An explicit list of booleans is applied by position."""
        scale = Nominal([False, True])
        mapping = Fill().get_mapping(scale, cat_vector)
        assert_array_equal(mapping([0, 1, 0]), [False, True, False])

    def test_mapping_truthy_list(self, cat_vector):
        """Integer 0/1 scale values produce boolean output."""
        scale = Nominal([0, 1])
        mapping = Fill().get_mapping(scale, cat_vector)
        assert_array_equal(mapping([0, 1, 0]), [False, True, False])

    def test_mapping_dict(self, cat_vector):
        """A dict of booleans is applied per level."""
        by_level = dict(zip(cat_vector.unique(), [False, True]))
        mapping = Fill().get_mapping(Nominal(by_level), cat_vector)
        assert_array_equal(mapping([0, 1, 0]), [False, True, False])

    def test_cycle_warning(self):
        """Three levels with default fill values emit a UserWarning."""
        three_levels = pd.Series(["a", "b", "c"])
        pattern = "The variable assigned to fill"
        with pytest.warns(UserWarning, match=pattern):
            Fill().get_mapping(Nominal(), three_levels)

    def test_values_error(self):
        """A string passed as scale values is rejected with a TypeError."""
        data = pd.Series(["a", "b"])
        pattern = "Scale values for fill must be"
        with pytest.raises(TypeError, match=pattern):
            Fill().get_mapping(Nominal("bad_values"), data)


class IntervalBase(DataFixtures):
    """Shared tests for interval properties; the tests read ``self.prop``."""

    def norm(self, x):
        """Rescale ``x`` linearly so its minimum is 0 and its maximum is 1."""
        lo, hi = x.min(), x.max()
        return (x - lo) / (hi - lo)

    @pytest.mark.parametrize("data_type,scale_class", [
        ("cat", Nominal),
        ("num", Continuous),
        ("bool", Boolean),
    ])
    def test_default(self, data_type, scale_class, vectors):
        """Check the default scale class for each kind of data."""
        data = vectors[data_type]
        inferred = self.prop().default_scale(data)
        assert isinstance(inferred, scale_class)

    @pytest.mark.parametrize("arg,data_type,scale_class", [
        ((1, 3), "cat", Nominal),
        ((1, 3), "num", Continuous),
        ((1, 3), "bool", Boolean),
        ([1, 2, 3], "cat", Nominal),
        ([1, 2, 3], "num", Nominal),
        ([1, 3], "bool", Boolean),
        ({"a": 1, "b": 3, "c": 2}, "cat", Nominal),
        ({2: 1, 4: 3, 8: 2}, "num", Nominal),
        ({True: 4, False: 2}, "bool", Boolean),
    ])
    def test_inference(self, arg, data_type, scale_class, vectors):
        """Check the scale class inferred from `arg` and that it keeps `arg`."""
        data = vectors[data_type]
        inferred = self.prop().infer_scale(arg, data)
        assert isinstance(inferred, scale_class)
        assert inferred.values == arg

    def test_mapped_interval_numeric(self, num_vector):
        """With a Continuous scale, 0 and 1 map to the ends of the default range."""
        mapping = self.prop().get_mapping(Continuous(), num_vector)
        endpoints = mapping([0, 1])
        assert_array_equal(endpoints, self.prop().default_range)

    def test_mapped_interval_categorical(self, cat_vector):
        """With a Nominal scale, the last and first levels span the default range."""
        mapping = self.prop().get_mapping(Nominal(), cat_vector)
        last_index = cat_vector.nunique() - 1
        endpoints = mapping([last_index, 0])
        assert_array_equal(endpoints, self.prop().default_range)

    def test_bad_scale_values_numeric_data(self, num_vector):
        """Continuous scale values that are not a 2-tuple raise a TypeError."""
        prop_name = self.prop.__name__.lower()
        err_stem = (
            f"Values for {prop_name} variables with Continuous scale must be 2-tuple"
        )

        # A string is not a tuple at all
        not_a_tuple = f"{err_stem}; not ."
        with pytest.raises(TypeError, match=not_a_tuple):
            self.prop().get_mapping(Continuous("abc"), num_vector)

        # A tuple of the wrong length is reported with its size
        wrong_length = f"{err_stem}; not 3-tuple."
        with pytest.raises(TypeError, match=wrong_length):
            self.prop().get_mapping(Continuous((1, 2, 3)), num_vector)

    def test_bad_scale_values_categorical_data(self, cat_vector):
        """A string passed as Nominal scale values raises a TypeError."""
        prop_name = self.prop.__name__.lower()
        err_text = f"Values for {prop_name} variables with Nominal scale"
        with pytest.raises(TypeError, match=err_text):
            self.prop().get_mapping(Nominal("abc"), cat_vector)


class TestAlpha(IntervalBase):
    """Run the shared interval-property tests against Alpha."""

    prop = Alpha


class TestLineWidth(IntervalBase):
    """Run the shared interval-property tests against LineWidth."""

    prop = LineWidth

    def test_rcparam_default(self):
        """With ``lines.linewidth`` set to 2, the default range should be (1, 4)."""
        overrides = {"lines.linewidth": 2}
        with mpl.rc_context(overrides):
            assert self.prop().default_range == (1, 4)


class TestEdgeWidth(IntervalBase):
    """Run the shared interval-property tests against EdgeWidth."""

    prop = EdgeWidth

    def test_rcparam_default(self):
        """With ``patch.linewidth`` set to 2, the default range should be (1, 4)."""
        overrides = {"patch.linewidth": 2}
        with mpl.rc_context(overrides):
            assert self.prop().default_range == (1, 4)


class TestPointSize(IntervalBase):
    """Run the shared interval-property tests against PointSize, plus scaling."""

    prop = PointSize

    def test_areal_scaling_numeric(self, num_vector):
        """Continuous output should be linear in the square of the size."""
        limits = 5, 10
        mapping = self.prop().get_mapping(Continuous(limits), num_vector)
        unit = np.linspace(0, 1, 6)
        # Interpolate between the squared limits, then take the root
        lo_sq, hi_sq = np.square(limits)
        expected = np.sqrt(np.linspace(lo_sq, hi_sq, num=len(unit)))
        assert_array_equal(mapping(unit), expected)

    def test_areal_scaling_categorical(self, cat_vector):
        """Nominal output for three indices: 4, sqrt(10), 2."""
        limits = (2, 4)
        mapping = self.prop().get_mapping(Nominal(limits), cat_vector)
        mapped = mapping(np.arange(3))
        assert_array_equal(mapped, [4, np.sqrt(10), 2])


================================================
FILE: tests/_core/test_rules.py
================================================

import numpy as np
import pandas as pd

import pytest

from seaborn._core.rules import (
    VarType,
    variable_type,
    categorical_order,
)


def test_vartype_object():
    """Exercise VarType equality checks and the assertions guarding its inputs."""
    numeric = VarType("numeric")
    assert numeric == "numeric"
    assert numeric != "categorical"

    # Comparing against "number" is expected to trip an assertion
    with pytest.raises(AssertionError):
        numeric == "number"

    # Constructing a VarType from "date" is expected to do the same
    with pytest.raises(AssertionError):
        VarType("date")


def test_variable_type():
    """Check `variable_type` across dtypes, null handling, and boolean options."""

    # Float, int, and object-typed numbers are all numeric
    s = pd.Series([1., 2., 3.])
    assert variable_type(s) == "numeric"
    assert variable_type(s.astype(int)) == "numeric"
    assert variable_type(s.astype(object)) == "numeric"

    # Missing values do not change the result, whatever their flavor
    s = pd.Series([1, 2, 3, np.nan], dtype=object)
    assert variable_type(s) == "numeric"

    s = pd.Series([np.nan, np.nan])
    assert variable_type(s) == "numeric"

    s = pd.Series([pd.NA, pd.NA])
    assert variable_type(s) == "numeric"

    s = pd.Series([1, 2, pd.NA], dtype="Int64")
    assert variable_type(s) == "numeric"

    s = pd.Series([1, 2, pd.NA], dtype=object)
    assert variable_type(s) == "numeric"

    # Strings are categorical even when they look like numbers
    s = pd.Series(["1", "2", "3"])
    assert variable_type(s) == "categorical"

    # Booleans are numeric by default; `boolean_type` overrides that
    s = pd.Series([True, False, False])
    assert variable_type(s) == "numeric"
    assert variable_type(s, boolean_type="categorical") == "categorical"
    assert variable_type(s, boolean_type="boolean") == "boolean"

    # This should arguably be datetime, but we don't currently handle it correctly
    # Test is mainly asserting that this doesn't fail on the boolean check.
    s = pd.timedelta_range(1, periods=3, freq="D").to_series()
    assert variable_type(s) == "categorical"

    # A category dtype stays categorical whatever `boolean_type` is
    s_cat = s.astype("category")
    assert variable_type(s_cat, boolean_type="categorical") == "categorical"
    assert variable_type(s_cat, boolean_type="numeric") == "categorical"
    assert variable_type(s_cat, boolean_type="boolean") == "categorical"

    # 0/1 integers count as boolean unless `strict_boolean` is set
    s = pd.Series([1, 0, 0])
    assert variable_type(s, boolean_type="boolean") == "boolean"
    assert variable_type(s, boolean_type="boolean", strict_boolean=True) == "numeric"

    # Timestamps are datetime, including when stored as objects
    s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)])
    assert variable_type(s) == "datetime"
    assert variable_type(s.astype(object)) == "datetime"


def test_categorical_order():
    """Check `categorical_order` with strings, numbers, categories, and nulls."""
    strings = pd.Series(["a", "c", "c", "b", "a", "d"])
    numbers = pd.Series([3, 2, 5, 1, 4])
    explicit = ["a", "b", "c", "d"]

    # Strings come back in order of first appearance
    assert categorical_order(strings) == ["a", "c", "b", "d"]

    # An explicit order is returned as given, even when it is a subset
    assert categorical_order(strings, explicit) == explicit
    assert categorical_order(strings, ["b", "a"]) == ["b", "a"]

    # Numeric levels come back sorted
    assert categorical_order(numbers) == [1, 2, 3, 4, 5]
    assert categorical_order(pd.Series(numbers)) == [1, 2, 3, 4, 5]

    # With a category dtype, the order follows the categories
    numbers_cat = pd.Series(pd.Categorical(numbers, numbers))
    assert categorical_order(numbers_cat) == list(numbers)

    strings_cat = pd.Series(strings).astype("category")
    assert categorical_order(strings_cat) == list(strings_cat.cat.categories)
    assert categorical_order(strings_cat, ["b", "a"]) == ["b", "a"]

    # A missing value is not reported as a level
    with_null = pd.Series(["a", np.nan, "c", "c", "b", "a", "d"])
    assert categorical_order(with_null) == ["a", "c", "b", "d"]


================================================
FILE: tests/_core/test_scales.py
================================================
import re

import numpy as np
import pandas as pd
import matplotlib as mpl

import pytest
from numpy.testing import assert_array_equal
from pandas.testing import assert_series_equal

from seaborn._core.plot import Plot
from seaborn._core.scales import (
    Nominal,
    Continuous,
    Boolean,
    Temporal,
    PseudoAxis,
)
from seaborn._core.properties import (
    IntervalProperty,
    ObjectProperty,
    Coordinate,
    Alpha,
    Color,
    Fill,
)
from seaborn.palettes import color_palette
from seaborn.utils import _version_predates


class TestContinuous:
    """Tests for Continuous scales: value mapping, tick locators, and tick labels."""

    @pytest.fixture
    def x(self):
        """Small float series named "x" used as input data."""
        return pd.Series([1, 3, 9], name="x", dtype=float)

    def setup_ticks(self, x, *args, **kwargs):
        """Return a PseudoAxis viewing (0, 1) for a scale configured via ``tick``."""

        s = Continuous().tick(*args, **kwargs)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(0, 1)
        return a

    def setup_labels(self, x, *args, **kwargs):
        """Return a PseudoAxis viewing (0, 1) configured via ``label``, plus its ticks."""

        s = Continuous().label(*args, **kwargs)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(0, 1)
        locs = a.major.locator()
        return a, locs

    def test_coordinate_defaults(self, x):
        """A default coordinate scale returns the data unchanged."""

        s = Continuous()._setup(x, Coordinate())
        assert_series_equal(s(x), x)

    def test_coordinate_transform(self, x):
        """The "log" transform maps coordinates to their base-10 logarithm."""

        s = Continuous(trans="log")._setup(x, Coordinate())
        assert_series_equal(s(x), np.log10(x))

    def test_coordinate_transform_with_parameter(self, x):
        """The "pow3" transform cubes the coordinates."""

        s = Continuous(trans="pow3")._setup(x, Coordinate())
        assert_series_equal(s(x), np.power(x, 3))

    def test_coordinate_transform_error(self, x):
        """An unknown transform name raises a ValueError during setup."""

        s = Continuous(trans="bad")
        with pytest.raises(ValueError, match="Unknown value provided"):
            s._setup(x, Coordinate())

    def test_interval_defaults(self, x):
        """By default the data range is rescaled onto [0, 1]."""

        s = Continuous()._setup(x, IntervalProperty())
        assert_array_equal(s(x), [0, .25, 1])

    def test_interval_with_range(self, x):
        """A 2-tuple of values sets the output range."""

        s = Continuous((1, 3))._setup(x, IntervalProperty())
        assert_array_equal(s(x), [1, 1.5, 3])

    def test_interval_with_norm(self, x):
        """Data outside the ``norm`` limits maps outside [0, 1]."""

        s = Continuous(norm=(3, 7))._setup(x, IntervalProperty())
        assert_array_equal(s(x), [-.5, 0, 1.5])

    def test_interval_with_range_norm_and_transform(self, x):
        """Output range, norm, and transform can be combined positionally."""

        x = pd.Series([1, 10, 100])
        # TODO param order?
        s = Continuous((2, 3), (10, 100), "log")._setup(x, IntervalProperty())
        assert_array_equal(s(x), [1, 2, 3])

    def test_interval_with_bools(self):
        """Boolean data maps to 1 for True and 0 for False."""

        x = pd.Series([True, False, False])
        s = Continuous()._setup(x, IntervalProperty())
        assert_array_equal(s(x), [1, 0, 0])

    def test_color_defaults(self, x):
        """The default color mapping matches the "ch:" colormap (RGB only)."""

        cmap = color_palette("ch:", as_cmap=True)
        s = Continuous()._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_named_values(self, x):
        """A palette name given as values selects that colormap."""

        cmap = color_palette("viridis", as_cmap=True)
        s = Continuous("viridis")._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_tuple_values(self, x):
        """A 2-tuple of colors matches the equivalent "blend:" colormap."""

        cmap = color_palette("blend:b,g", as_cmap=True)
        s = Continuous(("b", "g"))._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_callable_values(self, x):
        """A colormap object given as values is used directly."""

        cmap = color_palette("light:r", as_cmap=True)
        s = Continuous(cmap)._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .25, 1])[:, :3])  # FIXME RGBA

    def test_color_with_norm(self, x):
        """``norm`` limits apply before the colormap lookup."""

        cmap = color_palette("ch:", as_cmap=True)
        s = Continuous(norm=(3, 7))._setup(x, Color())
        assert_array_equal(s(x), cmap([-.5, 0, 1.5])[:, :3])  # FIXME RGBA

    def test_color_with_transform(self, x):
        """With a log transform, powers of ten are evenly spaced in the colormap."""

        x = pd.Series([1, 10, 100], name="x", dtype=float)
        cmap = color_palette("ch:", as_cmap=True)
        s = Continuous(trans="log")._setup(x, Color())
        assert_array_equal(s(x), cmap([0, .5, 1])[:, :3])  # FIXME RGBA

    def test_tick_locator(self, x):
        """A matplotlib locator passed to ``tick`` determines the tick locations."""

        locs = [.2, .5, .8]
        locator = mpl.ticker.FixedLocator(locs)
        a = self.setup_ticks(x, locator)
        assert_array_equal(a.major.locator(), locs)

    def test_tick_locator_input_check(self, x):
        """Passing a tuple as the locator raises a TypeError."""

        err = "Tick locator must be an instance of .*?, not ."
        with pytest.raises(TypeError, match=err):
            Continuous().tick((1, 2))

    def test_tick_upto(self, x):
        """``upto=n`` yields at most ``n + 1`` major ticks."""

        for n in [2, 5, 10]:
            a = self.setup_ticks(x, upto=n)
            assert len(a.major.locator()) <= (n + 1)

    def test_tick_every(self, x):
        """``every=d`` spaces major ticks ``d`` apart."""

        for d in [.05, .2, .5]:
            a = self.setup_ticks(x, every=d)
            assert np.allclose(np.diff(a.major.locator()), d)

    def test_tick_every_between(self, x):
        """``every`` with ``between`` steps from the lower limit in increments of ``d``."""

        lo, hi = .2, .8
        for d in [.05, .2, .5]:
            a = self.setup_ticks(x, every=d, between=(lo, hi))
            expected = np.arange(lo, hi + d, d)
            assert_array_equal(a.major.locator(), expected)

    def test_tick_at(self, x):
        """``at`` places major ticks at exactly the given locations."""

        locs = [.2, .5, .9]
        a = self.setup_ticks(x, at=locs)
        assert_array_equal(a.major.locator(), locs)

    def test_tick_count(self, x):
        """``count=n`` spreads ``n`` ticks evenly over the view interval."""

        n = 8
        a = self.setup_ticks(x, count=n)
        assert_array_equal(a.major.locator(), np.linspace(0, 1, n))

    def test_tick_count_between(self, x):
        """``count`` with ``between`` spreads the ticks evenly over those limits."""

        n = 5
        lo, hi = .2, .7
        a = self.setup_ticks(x, count=n, between=(lo, hi))
        assert_array_equal(a.major.locator(), np.linspace(lo, hi, n))

    def test_tick_minor(self, x):
        """``minor=n`` inserts ``n`` minor ticks between the two major ticks."""

        n = 3
        a = self.setup_ticks(x, count=2, minor=n)
        expected = np.linspace(0, 1, n + 2)
        if _version_predates(mpl, "3.8.0rc1"):
            # I am not sure why matplotlib <3.8 minor ticks include the
            # largest major location but exclude the smallest one ...
            expected = expected[1:]
        assert_array_equal(a.minor.locator(), expected)

    def test_log_tick_default(self, x):
        """Default major ticks on a log scale are one decade apart."""

        s = Continuous(trans="log")._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(.5, 1050)
        ticks = a.major.locator()
        assert np.allclose(np.diff(np.log10(ticks)), 1)

    def test_log_tick_upto(self, x):
        """``upto=n`` on a log scale sets the locator's ``numticks`` to ``n``."""

        n = 3
        s = Continuous(trans="log").tick(upto=n)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        assert a.major.locator.numticks == n

    def test_log_tick_count(self, x):
        """On a log scale, ``count`` raises a RuntimeError unless ``between`` is given."""

        with pytest.raises(RuntimeError, match="`count` requires"):
            Continuous(trans="log").tick(count=4)

        s = Continuous(trans="log").tick(count=4, between=(1, 1000))
        a = PseudoAxis(s._setup(x, Coordinate())._matplotlib_scale)
        a.set_view_interval(.5, 1050)
        assert_array_equal(a.major.locator(), [1, 10, 100, 1000])

    def test_log_tick_format_disabled(self, x):
        """``label(base=None)`` on a log scale produces plain-digit labels."""

        s = Continuous(trans="log").label(base=None)._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(20, 20000)
        labels = a.major.formatter.format_ticks(a.major.locator())
        for text in labels:
            assert re.match(r"^\d+$", text)

    def test_log_tick_every(self, x):
        """``every`` raises a RuntimeError on a log scale."""

        with pytest.raises(RuntimeError, match="`every` not supported"):
            Continuous(trans="log").tick(every=2)

    def test_symlog_tick_default(self, x):
        """Default symlog ticks are symmetric, include 0, and step by decades."""

        s = Continuous(trans="symlog")._setup(x, Coordinate())
        a = PseudoAxis(s._matplotlib_scale)
        a.set_view_interval(-1050, 1050)
        ticks = a.major.locator()
        assert ticks[0] == -ticks[-1]
        pos_ticks = np.sort(np.unique(np.abs(ticks)))
        assert np.allclose(np.diff(np.log10(pos_ticks[1:])), 1)
        assert pos_ticks[0] == 0

    def test_label_formatter(self, x):
        """A matplotlib formatter passed to ``label`` formats the tick labels."""

        fmt = mpl.ticker.FormatStrFormatter("%.3f")
        a, locs = self.setup_labels(x, fmt)
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^\d\.\d{3}$", text)

    def test_label_like_pattern(self, x):
        """``like`` accepts a bare format spec such as ".4f"."""

        a, locs = self.setup_labels(x, like=".4f")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^\d\.\d{4}$", text)

    def test_label_like_string(self, x):
        """``like`` accepts a format string with an ``x`` field."""

        a, locs = self.setup_labels(x, like="x = {x:.1f}")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^x = \d\.\d$", text)

    def test_label_like_function(self, x):
        """``like`` accepts a callable that formats each tick value."""

        a, locs = self.setup_labels(x, like="{:^5.1f}".format)
        labels = a.major.formatter.format_ticks(locs)
        for text in labels:
            assert re.match(r"^ \d\.\d $", text)

    def test_label_base(self, x):
        """``base=2`` gives labels written as powers of two (or empty strings)."""

        a, locs = self.setup_labels(100 * x, base=2)
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:]:
            assert not text or "2^" in text

    def test_label_unit(self, x):
        """``unit="g"`` gives labels like "500 mg" for the interior ticks."""

        a, locs = self.setup_labels(1000 * x, unit="g")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:-1]:
            assert re.match(r"^\d+ mg$", text)

    def test_label_unit_with_sep(self, x):
        """A ``(sep, unit)`` tuple with an empty separator drops the space."""

        a, locs = self.setup_labels(1000 * x, unit=("", "g"))
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:-1]:
            assert re.match(r"^\d+mg$", text)

    def test_label_empty_unit(self, x):
        """An empty unit leaves only the prefix, with no space before it."""

        a, locs = self.setup_labels(1000 * x, unit="")
        labels = a.major.formatter.format_ticks(locs)
        for text in labels[1:-1]:
            assert re.match(r"^\d+m$", text)

    def test_label_base_from_transform(self, x):
        """A log transform labels 100 with the exponent form ``10^{2}``."""

        s = Continuous(trans="log")
        a = PseudoAxis(s._setup(x, Coordinate())._matplotlib_scale)
        a.set_view_interval(10, 1000)
        label, = a.major.formatter.format_ticks([100])
        assert r"10^{2}" in label

    def test_label_type_checks(self):
        """Invalid ``formatter`` and ``like`` arguments raise a TypeError."""

        s = Continuous()
        with pytest.raises(TypeError, match="Label formatter must be"):
            s.label("{x}")

        with pytest.raises(TypeError, match="`like` must be"):
            s.label(like=2)


class TestNominal:
    """Tests for Nominal scales applied to several kinds of property."""

    @pytest.fixture
    def x(self):
        """String series named "x" with three levels, one of them repeated."""
        labels = ["a", "c", "b", "c"]
        return pd.Series(labels, name="x")

    @pytest.fixture
    def y(self):
        """Numeric series named "y" with three levels, one of them repeated."""
        numbers = [1, -1.5, 3, -1.5]
        return pd.Series(numbers, name="y")

    def test_coordinate_defaults(self, x):
        """Levels are numbered in order of first appearance."""
        scale = Nominal()._setup(x, Coordinate())
        expected = np.array([0, 1, 2, 1], float)
        assert_array_equal(scale(x), expected)

    def test_coordinate_with_order(self, x):
        """An explicit order determines the numbering."""
        scale = Nominal(order=["a", "b", "c"])._setup(x, Coordinate())
        expected = np.array([0, 2, 1, 2], float)
        assert_array_equal(scale(x), expected)

    def test_coordinate_with_subset_order(self, x):
        """Levels left out of the order map to NaN."""
        scale = Nominal(order=["c", "a"])._setup(x, Coordinate())
        expected = np.array([1, 0, np.nan, 0], float)
        assert_array_equal(scale(x), expected)

    def test_coordinate_axis(self, x):
        """Setting up against an axis installs a formatter that shows the levels."""
        axes = mpl.figure.Figure().subplots()
        scale = Nominal()._setup(x, Coordinate(), axes.xaxis)
        assert_array_equal(scale(x), np.array([0, 1, 2, 1], float))
        formatter = axes.xaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1, 2]) == ["a", "c", "b"]

    def test_coordinate_axis_with_order(self, x):
        """The axis formatter follows an explicit order."""
        order = ["a", "b", "c"]
        axes = mpl.figure.Figure().subplots()
        scale = Nominal(order=order)._setup(x, Coordinate(), axes.xaxis)
        assert_array_equal(scale(x), np.array([0, 2, 1, 2], float))
        formatter = axes.xaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1, 2]) == order

    def test_coordinate_axis_with_subset_order(self, x):
        """Tick positions beyond the ordered levels get an empty label."""
        order = ["c", "a"]
        axes = mpl.figure.Figure().subplots()
        scale = Nominal(order=order)._setup(x, Coordinate(), axes.xaxis)
        assert_array_equal(scale(x), np.array([1, 0, np.nan, 0], float))
        formatter = axes.xaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1, 2]) == order + [""]

    def test_coordinate_axis_with_category_dtype(self, x):
        """A category dtype supplies the order, including unused categories."""
        order = ["b", "a", "d", "c"]
        x = x.astype(pd.CategoricalDtype(order))
        axes = mpl.figure.Figure().subplots()
        scale = Nominal()._setup(x, Coordinate(), axes.xaxis)
        assert_array_equal(scale(x), np.array([1, 3, 0, 3], float))
        formatter = axes.xaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1, 2, 3]) == order

    def test_coordinate_numeric_data(self, y):
        """Numeric levels are numbered in sorted order and labelled as floats."""
        axes = mpl.figure.Figure().subplots()
        scale = Nominal()._setup(y, Coordinate(), axes.yaxis)
        assert_array_equal(scale(y), np.array([1, 0, 2, 0], float))
        formatter = axes.yaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1, 2]) == ["-1.5", "1.0", "3.0"]

    def test_coordinate_numeric_data_with_order(self, y):
        """An explicit numeric order is used; values outside it map to NaN."""
        order = [1, 4, -1.5]
        axes = mpl.figure.Figure().subplots()
        scale = Nominal(order=order)._setup(y, Coordinate(), axes.yaxis)
        assert_array_equal(scale(y), np.array([0, 2, np.nan, 2], float))
        formatter = axes.yaxis.get_major_formatter()
        assert formatter.format_ticks([0, 1, 2]) == ["1.0", "4.0", "-1.5"]

    def test_color_defaults(self, x):
        """Levels take successive colors from the default palette."""
        scale = Nominal()._setup(x, Color())
        palette = color_palette()
        expected = [palette[i] for i in (0, 1, 2, 1)]
        assert_array_equal(scale(x), expected)

    def test_color_named_palette(self, x):
        """A palette name is sampled with one color per level."""
        name = "flare"
        scale = Nominal(name)._setup(x, Color())
        palette = color_palette(name, 3)
        expected = [palette[i] for i in (0, 1, 2, 1)]
        assert_array_equal(scale(x), expected)

    def test_color_list_palette(self, x):
        """A list of colors is applied to the levels by position."""
        palette = color_palette("crest", 3)
        scale = Nominal(palette)._setup(x, Color())
        expected = [palette[i] for i in (0, 1, 2, 1)]
        assert_array_equal(scale(x), expected)

    def test_color_dict_palette(self, x):
        """A dict of colors is applied per level."""
        palette = color_palette("crest", 3)
        by_level = dict(zip("bac", palette))
        scale = Nominal(by_level)._setup(x, Color())
        expected = [palette[i] for i in (1, 2, 0, 2)]
        assert_array_equal(scale(x), expected)

    def test_color_numeric_data(self, y):
        """Numeric levels take default colors in sorted order."""
        scale = Nominal()._setup(y, Color())
        palette = color_palette()
        expected = [palette[i] for i in (1, 0, 2, 0)]
        assert_array_equal(scale(y), expected)

    def test_color_numeric_with_order_subset(self, y):
        """Values left out of the order map to an all-NaN color."""
        scale = Nominal(order=[-1.5, 1])._setup(y, Color())
        first, second = color_palette(n_colors=2)
        missing = (np.nan, np.nan, np.nan)
        assert_array_equal(scale(y), [second, first, missing, first])

    @pytest.mark.xfail(reason="Need to sort out float/int order")
    def test_color_numeric_int_float_mix(self):
        """Known failure: an order mixing a float and an int with integer data."""
        # NOTE(review): `expected` has three entries while `z` has two;
        # confirm the intended data before removing the xfail marker.
        z = pd.Series([1, 2], name="z")
        scale = Nominal(order=[1.0, 2])._setup(z, Color())
        first, second = color_palette(n_colors=2)
        missing = (np.nan, np.nan, np.nan)
        assert_array_equal(scale(z), [first, missing, second])

    def test_color_alpha_in_palette(self, x):
        """RGBA tuples given in a list palette come back with their alpha."""
        rgba = [(.2, .2, .3, .5), (.1, .2, .3, 1), (.5, .6, .2, 0)]
        scale = Nominal(rgba)._setup(x, Color())
        expected = [rgba[i] for i in (0, 1, 2, 1)]
        assert_array_equal(scale(x), expected)

    def test_color_unknown_palette(self, x):
        """An unknown palette name raises a ValueError that quotes the name."""
        pal = "not_a_palette"
        err = f"'{pal}' is not a valid palette name"
        with pytest.raises(ValueError, match=err):
            Nominal(pal)._setup(x, Color())

    def test_object_defaults(self, x):
        """Without scale values, an object property supplies its own defaults."""

        class MockProperty(ObjectProperty):
            def _default_values(self, n):
                return list("xyz"[:n])

        scale = Nominal()._setup(x, MockProperty())
        mapped = scale(x)
        assert mapped == ["x", "y", "z", "y"]

    def test_object_list(self, x):
        """A list of objects is applied to the levels by position."""
        objects = ["x", "y", "z"]
        scale = Nominal(objects)._setup(x, ObjectProperty())
        mapped = scale(x)
        assert mapped == ["x", "y", "z", "y"]

    def test_object_dict(self, x):
        """A dict of objects is applied per level."""
        by_level = {"a": "x", "b": "y", "c": "z"}
        scale = Nominal(by_level)._setup(x, ObjectProperty())
        mapped = scale(x)
        assert mapped == ["x", "z", "y", "z"]

    def test_object_order(self, x):
        """With an explicit order, list values follow that order."""
        objects = ["x", "y", "z"]
        scale = Nominal(objects, order=["c", "a", "b"])._setup(x, ObjectProperty())
        mapped = scale(x)
        assert mapped == ["y", "x", "z", "x"]

    def test_object_order_subset(self, x):
        """Levels left out of the order map to None."""
        objects = ["x", "y"]
        scale = Nominal(objects, order=["a", "c"])._setup(x, ObjectProperty())
        mapped = scale(x)
        assert mapped == ["x", "y", None, "y"]

    def test_objects_that_are_weird(self, x):
        """Tuples of differing length and a dict pass through as values."""
        objects = [("x", 1), (None, None, 0), {}]
        scale = Nominal(objects)._setup(x, ObjectProperty())
        expected = [objects[i] for i in (0, 1, 2, 1)]
        assert scale(x) == expected

    def test_alpha_default(self, x):
        """Default alpha values run from .95 for the first level to .3 for the last."""
        scale = Nominal()._setup(x, Alpha())
        expected = [.95, .625, .3, .625]
        assert_array_equal(scale(x), expected)

    def test_fill(self):
        """With two levels, the first is filled and the second is not."""
        data = pd.Series(["a", "a", "b", "a"], name="x")
        scale = Nominal()._setup(data, Fill())
        expected = [True, True, False, True]
        assert_array_equal(scale(data), expected)

    def test_fill_dict(self):
        """A dict of booleans is applied per level."""
        data = pd.Series(["a", "a", "b", "a"], name="x")
        by_level = {"a": False, "b": True}
        scale = Nominal(by_level)._setup(data, Fill())
        expected = [False, False, True, False]
        assert_array_equal(scale(data), expected)

    def test_fill_nunique_warning(self):
        """Three levels emit a UserWarning and the fill values repeat."""
        data = pd.Series(["a", "b", "c", "a", "b"], name="x")
        pattern = "The variable assigned to fill"
        with pytest.warns(UserWarning, match=pattern):
            scale = Nominal()._setup(data, Fill())
        expected = [True, False, True, True, False]
        assert_array_equal(scale(data), expected)

    def test_interval_defaults(self, x):
        """The property's default range is spread over the levels, high to low."""

        class MockProperty(IntervalProperty):
            _default_range = (1, 2)

        scale = Nominal()._setup(x, MockProperty())
        expected = [2, 1.5, 1, 1.5]
        assert_array_equal(scale(x), expected)

    def test_interval_tuple(self, x):
        """A 2-tuple of values is treated as a range spread over the levels."""
        scale = Nominal((1, 2))._setup(x, IntervalProperty())
        expected = [2, 1.5, 1, 1.5]
        assert_array_equal(scale(x), expected)

    def test_interval_tuple_numeric(self, y):
        """With numeric data, the range follows the sorted level order."""
        scale = Nominal((1, 2))._setup(y, IntervalProperty())
        expected = [1.5, 2, 1, 2]
        assert_array_equal(scale(y), expected)

    def test_interval_list(self, x):
        """A list of numbers is applied to the levels by position."""
        sizes = [2, 5, 4]
        scale = Nominal(sizes)._setup(x, IntervalProperty())
        expected = [2, 5, 4, 5]
        assert_array_equal(scale(x), expected)

    def test_interval_dict(self, x):
        """A dict of numbers is applied per level."""
        by_level = {"a": 3, "b": 4, "c": 6}
        scale = Nominal(by_level)._setup(x, IntervalProperty())
        expected = [3, 6, 4, 6]
        assert_array_equal(scale(x), expected)

    def test_interval_with_transform(self, x):
        """A property with square/sqrt transforms yields 4, sqrt(10), 2 for (2, 4)."""

        class MockProperty(IntervalProperty):
            _forward = np.square
            _inverse = np.sqrt

        scale = Nominal((2, 4))._setup(x, MockProperty())
        middle = np.sqrt(10)
        assert_array_equal(scale(x), [4, middle, 2, middle])

    def test_empty_data(self):
        """An empty series sets up and maps to an empty result."""
        empty = pd.Series([], dtype=object, name="x")
        scale = Nominal()._setup(empty, Coordinate())
        assert_array_equal(scale(empty), [])

    def test_finalize(self, x):
        """Finalizing on a y axis sets inverted limits, integer ticks, and labels."""
        axes = mpl.figure.Figure().subplots()
        scale = Nominal()._setup(x, Coordinate(), axes.yaxis)
        scale._finalize(Plot(), axes.yaxis)

        levels = x.unique()
        n_levels = len(levels)
        assert axes.get_ylim() == (n_levels - .5, -.5)
        assert_array_equal(axes.get_yticks(), list(range(n_levels)))
        for position, label in enumerate(levels):
            assert axes.yaxis.major.formatter(position) == label


class TestTemporal:
    """Tests for the Temporal scale, using a short series of dates."""

    @pytest.fixture
    def t(self):
        """Return three dates as a datetime series named "x"."""
        stamps = ["1972-09-27", "1975-06-24", "1980-12-14"]
        return pd.Series(pd.to_datetime(stamps), name="x")

    @pytest.fixture
    def x(self, t):
        """Return the dates from ``t`` converted with ``mpl.dates.date2num``."""
        return pd.Series(mpl.dates.date2num(t), name=t.name)

    def test_coordinate_defaults(self, t, x):
        """A coordinate scale maps dates onto their matplotlib date numbers."""
        scale = Temporal()._setup(t, Coordinate())
        assert_array_equal(scale(t), x)

    def test_interval_defaults(self, t, x):
        """Default interval output is the min-max normalized date number."""
        scale = Temporal()._setup(t, IntervalProperty())
        lo, hi = x.min(), x.max()
        assert_array_equal(scale(t), (x - lo) / (hi - lo))

    def test_interval_with_range(self, t, x):
        """An output range rescales the normalized values linearly."""
        vmin, vmax = 1, 3
        scale = Temporal((vmin, vmax))._setup(t, IntervalProperty())
        lo, hi = x.min(), x.max()
        unit = (x - lo) / (hi - lo)
        assert_array_equal(scale(t), unit * (vmax - vmin) + vmin)

    def test_interval_with_norm(self, t, x):
        """Explicit norm limits replace the data extremes in the normalization."""
        limits = t[1], t[2]
        scale = Temporal(norm=limits)._setup(t, IntervalProperty())
        lo, hi = mpl.dates.date2num(limits)
        assert_array_equal(scale(t), (x - lo) / (hi - lo))

    def test_color_defaults(self, t, x):
        """Default colors are drawn from the "ch:" colormap."""
        palette = color_palette("ch:", as_cmap=True)
        scale = Temporal()._setup(t, Color())
        lo, hi = x.min(), x.max()
        rgba = palette((x - lo) / (hi - lo))
        assert_array_equal(scale(t), rgba[:, :3])  # FIXME RGBA

    def test_color_named_values(self, t, x):
        """A colormap name passed to the scale selects that colormap."""
        cmap_name = "viridis"
        palette = color_palette(cmap_name, as_cmap=True)
        scale = Temporal(cmap_name)._setup(t, Color())
        lo, hi = x.min(), x.max()
        rgba = palette((x - lo) / (hi - lo))
        assert_array_equal(scale(t), rgba[:, :3])  # FIXME RGBA

    def test_coordinate_axis(self, t, x):
        """Setting up on a real axis installs the automatic date locator/formatter."""
        ax = mpl.figure.Figure().subplots()
        xaxis = ax.xaxis
        scale = Temporal()._setup(t, Coordinate(), xaxis)
        assert_array_equal(scale(t), x)
        assert isinstance(xaxis.get_major_locator(), mpl.dates.AutoDateLocator)
        assert isinstance(xaxis.get_major_formatter(), mpl.dates.AutoDateFormatter)

    def test_tick_locator(self, t):
        """A locator passed to ``tick`` is the one used by the scale's axis."""
        yearly = mpl.dates.YearLocator(month=3, day=15)
        scale = Temporal().tick(yearly)._setup(t, Coordinate())
        axis = PseudoAxis(scale._matplotlib_scale)
        axis.set_view_interval(0, 365)
        # NOTE(review): 73 is presumably March 15 counted in days from the
        # start of matplotlib's epoch year; confirm against the date epoch.
        assert 73 in axis.major.locator()

    def test_tick_upto(self, t, x):
        """``tick(upto=n)`` caps every ``maxticks`` entry of the locator at n."""
        limit = 8
        ax = mpl.figure.Figure().subplots()
        Temporal().tick(upto=limit)._setup(t, Coordinate(), ax.xaxis)
        maxticks = ax.xaxis.get_major_locator().maxticks
        assert set(maxticks.values()) == {limit}

    def test_label_formatter(self, t):
        """A formatter passed to ``label`` is used to format tick labels."""
        year_only = mpl.dates.DateFormatter("%Y")
        scale = Temporal().label(year_only)._setup(t, Coordinate())
        axis = PseudoAxis(scale._matplotlib_scale)
        axis.set_view_interval(10, 1000)
        (label,) = axis.major.formatter.format_ticks([100])
        assert label == "1970"

    def test_label_concise(self, t, x):
        """``label(concise=True)`` installs a ConciseDateFormatter on the axis."""
        ax = mpl.figure.Figure().subplots()
        Temporal().label(concise=True)._setup(t, Coordinate(), ax.xaxis)
        installed = ax.xaxis.get_major_formatter()
        assert isinstance(installed, mpl.dates.ConciseDateFormatter)


class TestBoolean:
    """Tests for the Boolean scale across coordinate and semantic properties."""

    @pytest.fixture
    def x(self):
        """Return a four-element boolean series named "x"."""
        flags = [True, False, False, True]
        return pd.Series(flags, name="x", dtype=bool)

    def test_coordinate(self, x):
        """Coordinate output equals the data cast to float."""
        scale = Boolean()._setup(x, Coordinate())
        assert_array_equal(scale(x), x.astype(float))

    def test_coordinate_axis(self, x):
        """On a real axis, ticks at 0 and 1 are labeled "False" and "True"."""
        ax = mpl.figure.Figure().subplots()
        scale = Boolean()._setup(x, Coordinate(), ax.xaxis)
        assert_array_equal(scale(x), x.astype(float))
        labels = ax.xaxis.get_major_formatter().format_ticks([0, 1])
        assert labels == ["False", "True"]

    @pytest.mark.parametrize(
        "dtype,value",
        [
            (object, np.nan),
            (object, None),
            ("boolean", pd.NA),
        ]
    )
    def test_coordinate_missing(self, x, dtype, value):
        """A missing entry maps the same way as casting the data to float."""
        data = x.astype(dtype)
        data[2] = value
        scale = Boolean()._setup(data, Coordinate())
        assert_array_equal(scale(data), data.astype(float))

    def test_color_defaults(self, x):
        """Default colors: True takes the first palette color, False the second."""
        scale = Boolean()._setup(x, Color())
        palette = color_palette()
        # Inverting the flags turns True into index 0 and False into index 1.
        expected = [palette[int(flag)] for flag in ~x]
        assert_array_equal(scale(x), expected)

    def test_color_list_palette(self, x):
        """A two-color list is indexed the same way as the default palette."""
        palette = color_palette("crest", 2)
        scale = Boolean(palette)._setup(x, Color())
        expected = [palette[int(flag)] for flag in ~x]
        assert_array_equal(scale(x), expected)

    def test_color_tuple_palette(self, x):
        """A tuple of colors behaves like a list of colors."""
        palette = tuple(color_palette("crest", 2))
        scale = Boolean(palette)._setup(x, Color())
        expected = [palette[int(flag)] for flag in ~x]
        assert_array_equal(scale(x), expected)

    def test_color_dict_palette(self, x):
        """A dict palette is looked up directly by boolean value."""
        first, second = color_palette("crest", 2)
        lookup = {True: first, False: second}
        scale = Boolean(lookup)._setup(x, Color())
        expected = [lookup[flag] for flag in x]
        assert_array_equal(scale(x), expected)

    def test_object_defaults(self, x):
        """Default object values: True takes the first value, False the second."""
        candidates = ["x", "y", "z"]

        class MockProperty(ObjectProperty):
            def _default_values(self, n):
                return candidates[:n]

        scale = Boolean()._setup(x, MockProperty())
        expected = [candidates[int(flag)] for flag in ~x]
        assert scale(x) == expected

    def test_object_list(self, x):
        """A list of object values is indexed by the inverted flag."""
        outputs = ["x", "y"]
        scale = Boolean(outputs)._setup(x, ObjectProperty())
        expected = [outputs[int(flag)] for flag in ~x]
        assert scale(x) == expected

    def test_object_dict(self, x):
        """A dict of object values is looked up directly by boolean value."""
        lookup = {True: "x", False: "y"}
        scale = Boolean(lookup)._setup(x, ObjectProperty())
        expected = [lookup[flag] for flag in x]
        assert scale(x) == expected

    def test_fill(self, x):
        """The fill property passes boolean data through unchanged."""
        scale = Boolean()._setup(x, Fill())
        assert_array_equal(scale(x), x)

    def test_interval_defaults(self, x):
        """Default interval output: False maps to the low end, True to the high."""
        bounds = (1, 2)

        class MockProperty(IntervalProperty):
            _default_range = bounds

        scale = Boolean()._setup(x, MockProperty())
        expected = [bounds[int(flag)] for flag in x]
        assert_array_equal(scale(x), expected)

    def test_interval_tuple(self, x):
        """An explicit (low, high) tuple is indexed by the flag itself."""
        bounds = (3, 5)
        scale = Boolean(bounds)._setup(x, IntervalProperty())
        expected = [bounds[int(flag)] for flag in x]
        assert_array_equal(scale(x), expected)

    def test_finalize(self, x):
        """Finalizing on an x axis sets inverted limits, 0/1 ticks, and labels."""
        ax = mpl.figure.Figure().subplots()
        xaxis = ax.xaxis
        scale = Boolean()._setup(x, Coordinate(), xaxis)
        scale._finalize(Plot(), xaxis)
        assert ax.get_xlim() == (1.5, -.5)
        assert_array_equal(ax.get_xticks(), [0, 1])
        for position, text in enumerate(["False", "True"]):
            assert xaxis.major.formatter(position) == text


================================================
FILE: tests/_core/test_subplots.py
================================================
import itertools

import numpy as np
import pytest

from seaborn._core.subplots import Subplots


class TestSpecificationChecks:
    """Incompatible facet/pair specifications must raise a RuntimeError."""

    def test_both_facets_and_wrap(self):
        """Wrapping is rejected when faceting on both `col` and `row`."""
        facets = {"wrap": 3, "variables": {"col": "a", "row": "b"}}
        message = "Cannot wrap facets when specifying both `col` and `row`."
        with pytest.raises(RuntimeError, match=message):
            Subplots({}, facets, {})

    def test_cross_xy_pairing_and_wrap(self):
        """Wrapping is rejected when pairing on both `x` and `y`."""
        pairs = {"wrap": 3, "structure": {"x": ["a", "b"], "y": ["y", "z"]}}
        message = "Cannot wrap subplots when pairing on both `x` and `y`."
        with pytest.raises(RuntimeError, match=message):
            Subplots({}, {}, pairs)

    def test_col_facets_and_x_pairing(self):
        """Column facets cannot be combined with pairing on `x`."""
        facets = {"variables": {"col": "a"}}
        pairs = {"structure": {"x": ["x", "y"]}}
        message = "Cannot facet the columns while pairing on `x`."
        with pytest.raises(RuntimeError, match=message):
            Subplots({}, facets, pairs)

    def test_wrapped_columns_and_y_pairing(self):
        """Wrapped column facets cannot be combined with pairing on `y`."""
        facets = {"variables": {"col": "a"}, "wrap": 2}
        pairs = {"structure": {"y": ["x", "y"]}}
        message = "Cannot wrap the columns while pairing on `y`."
        with pytest.raises(RuntimeError, match=message):
            Subplots({}, facets, pairs)

    def test_wrapped_x_pairing_and_facetd_rows(self):
        """Wrapped `x` pairing cannot be combined with row facets."""
        facets = {"variables": {"row": "a"}}
        pairs = {"structure": {"x": ["x", "y"]}, "wrap": 2}
        message = "Cannot wrap the columns while faceting the rows."
        with pytest.raises(RuntimeError, match=message):
            Subplots({}, facets, pairs)


class TestSubplotSpec:
    """Tests of the grid shape and axis sharing computed by Subplots."""

    @staticmethod
    def _assert_grid(subplots, n_subplots, ncols, nrows, sharex, sharey):
        """Check the subplot count, grid dimensions, and sharing settings.

        Boolean sharing expectations are compared by identity and string
        expectations (e.g. "col" or "row") by equality.
        """
        grid = subplots.subplot_spec
        assert subplots.n_subplots == n_subplots
        assert grid["ncols"] == ncols
        assert grid["nrows"] == nrows
        for key, wanted in (("sharex", sharex), ("sharey", sharey)):
            if isinstance(wanted, bool):
                assert grid[key] is wanted
            else:
                assert grid[key] == wanted

    def test_single_subplot(self):
        """Empty specifications produce one fully shared subplot."""
        self._assert_grid(
            Subplots({}, {}, {}),
            n_subplots=1, ncols=1, nrows=1, sharex=True, sharey=True,
        )

    def test_single_facet(self):
        """Faceting on columns gives one column per level."""
        levels = list("abc")
        facets = {"variables": {"col": "a"}, "structure": {"col": levels}}
        self._assert_grid(
            Subplots({}, facets, {}),
            n_subplots=len(levels),
            ncols=len(levels),
            nrows=1,
            sharex=True,
            sharey=True,
        )

    def test_two_facets(self):
        """Faceting on both dimensions gives a full column-by-row grid."""
        col_levels = list("xy")
        row_levels = list("xyz")
        facets = {
            "variables": {"col": "a", "row": "b"},
            "structure": {"col": col_levels, "row": row_levels},
        }
        self._assert_grid(
            Subplots({}, facets, {}),
            n_subplots=len(col_levels) * len(row_levels),
            ncols=len(col_levels),
            nrows=len(row_levels),
            sharex=True,
            sharey=True,
        )

    def test_col_facet_wrapped(self):
        """Wrapped column facets use `wrap` columns and spill into extra rows."""
        wrap = 3
        levels = list("abcde")
        facets = {
            "variables": {"col": "b"},
            "structure": {"col": levels},
            "wrap": wrap,
        }
        self._assert_grid(
            Subplots({}, facets, {}),
            n_subplots=len(levels),
            ncols=wrap,
            nrows=len(levels) // wrap + 1,
            sharex=True,
            sharey=True,
        )

    def test_row_facet_wrapped(self):
        """Wrapped row facets use `wrap` rows and spill into extra columns."""
        wrap = 3
        levels = list("abcde")
        facets = {
            "variables": {"row": "b"},
            "structure": {"row": levels},
            "wrap": wrap,
        }
        self._assert_grid(
            Subplots({}, facets, {}),
            n_subplots=len(levels),
            ncols=len(levels) // wrap + 1,
            nrows=wrap,
            sharex=True,
            sharey=True,
        )

    def test_col_facet_wrapped_single_row(self):
        """A wrap larger than the number of levels leaves a single row."""
        levels = list("abc")
        facets = {
            "variables": {"col": "b"},
            "structure": {"col": levels},
            "wrap": len(levels) + 2,
        }
        self._assert_grid(
            Subplots({}, facets, {}),
            n_subplots=len(levels),
            ncols=len(levels),
            nrows=1,
            sharex=True,
            sharey=True,
        )

    def test_x_and_y_paired(self):
        """Pairing both variables shares x within columns and y within rows."""
        x_vars = ["x", "y", "z"]
        y_vars = ["a", "b"]
        pairs = {"structure": {"x": x_vars, "y": y_vars}}
        self._assert_grid(
            Subplots({}, {}, pairs),
            n_subplots=len(x_vars) * len(y_vars),
            ncols=len(x_vars),
            nrows=len(y_vars),
            sharex="col",
            sharey="row",
        )

    def test_x_paired(self):
        """Pairing on x gives one column per variable, sharing x by column."""
        x_vars = ["x", "y", "z"]
        self._assert_grid(
            Subplots({}, {}, {"structure": {"x": x_vars}}),
            n_subplots=len(x_vars),
            ncols=len(x_vars),
            nrows=1,
            sharex="col",
            sharey=True,
        )

    def test_y_paired(self):
        """Pairing on y gives one row per variable, sharing y by row."""
        y_vars = ["x", "y", "z"]
        self._assert_grid(
            Subplots({}, {}, {"structure": {"y": y_vars}}),
            n_subplots=len(y_vars),
            ncols=1,
            nrows=len(y_vars),
            sharex=True,
            sharey="row",
        )

    def test_x_paired_and_wrapped(self):
        """Wrapping paired x variables disables x sharing."""
        x_vars = ["a", "b", "x", "y", "z"]
        wrap = 3
        self._assert_grid(
            Subplots({}, {}, {"structure": {"x": x_vars}, "wrap": wrap}),
            n_subplots=len(x_vars),
            ncols=wrap,
            nrows=len(x_vars) // wrap + 1,
            sharex=False,
            sharey=True,
        )

    def test_y_paired_and_wrapped(self):
        """Wrapping paired y variables disables y sharing."""
        y_vars = ["a", "b", "x", "y", "z"]
        wrap = 2
        self._assert_grid(
            Subplots({}, {}, {"structure": {"y": y_vars}, "wrap": wrap}),
            n_subplots=len(y_vars),
            ncols=len(y_vars) // wrap + 1,
            nrows=wrap,
            sharex=True,
            sharey=False,
        )

    def test_y_paired_and_wrapped_single_row(self):
        """Wrapping paired y variables at 1 lays them out in a single row."""
        y_vars = ["x", "y", "z"]
        self._assert_grid(
            Subplots({}, {}, {"structure": {"y": y_vars}, "wrap": 1}),
            n_subplots=len(y_vars),
            ncols=len(y_vars),
            nrows=1,
            sharex=True,
            sharey=False,
        )

    def test_col_faceted_y_paired(self):
        """Column facets combined with y pairing form a level-by-variable grid."""
        y_vars = ["x", "y", "z"]
        levels = list("abc")
        facets = {"variables": {"col": "a"}, "structure": {"col": levels}}
        pairs = {"structure": {"y": y_vars}}
        self._assert_grid(
            Subplots({}, facets, pairs),
            n_subplots=len(levels) * len(y_vars),
            ncols=len(levels),
            nrows=len(y_vars),
            sharex=True,
            sharey="row",
        )

    def test_row_faceted_x_paired(self):
        """Row facets combined with x pairing form a variable-by-level grid."""
        x_vars = ["f", "s"]
        levels = list("abc")
        facets = {"variables": {"row": "a"}, "structure": {"row": levels}}
        pairs = {"structure": {"x": x_vars}}
        self._assert_grid(
            Subplots({}, facets, pairs),
            n_subplots=len(levels) * len(x_vars),
            ncols=len(x_vars),
            nrows=len(levels),
            sharex="col",
            sharey=True,
        )

    def test_x_any_y_paired_non_cross(self):
        """Non-crossed pairing gives one unshared subplot per x/y pair."""
        x_vars = ["a", "b", "c"]
        y_vars = ["x", "y", "z"]
        pairs = {"structure": {"x": x_vars, "y": y_vars}, "cross": False}
        self._assert_grid(
            Subplots({}, {}, pairs),
            n_subplots=len(x_vars),
            ncols=len(y_vars),
            nrows=1,
            sharex=False,
            sharey=False,
        )

    def test_x_any_y_paired_non_cross_wrapped(self):
        """Non-crossed pairs can be wrapped into multiple rows."""
        x_vars = ["a", "b", "c"]
        y_vars = ["x", "y", "z"]
        wrap = 2
        pairs = {
            "structure": {"x": x_vars, "y": y_vars},
            "cross": False,
            "wrap": wrap,
        }
        self._assert_grid(
            Subplots({}, {}, pairs),
            n_subplots=len(x_vars),
            ncols=wrap,
            nrows=len(x_vars) // wrap + 1,
            sharex=False,
            sharey=False,
        )

    def test_forced_unshared_facets(self):
        """Sharing values given in the subplot spec are kept as provided."""
        requested = {"sharex": False, "sharey": "row"}
        grid = Subplots(requested, {}, {}).subplot_spec
        assert grid["sharex"] is False
        assert grid["sharey"] == "row"


class TestSubplotElements:
    """Tests of the per-subplot entries yielded after `Subplots.init_figure`.

    Each entry is indexed like a dict; the tests read an "ax" key, the
    "x"/"y" variable keys, the "col"/"row" facet-level keys, and the
    boolean "left"/"right"/"top"/"bottom" edge flags.
    """

    @staticmethod
    def _assert_outer_edges(es, n_cols):
        """Check edge flags on the border of a full, unwrapped grid.

        Assumes ``es`` lists the entries row by row with ``n_cols`` per row.
        """
        for e in es[:n_cols]:
            assert e["top"]
        for e in es[::n_cols]:
            assert e["left"]
        for e in es[n_cols - 1::n_cols]:
            assert e["right"]
        for e in es[-n_cols:]:
            assert e["bottom"]

    def test_single_subplot(self):
        """A lone subplot touches every edge and has no facet levels."""
        s = Subplots({}, {}, {})
        f = s.init_figure({}, {})

        assert len(s) == 1
        for i, e in enumerate(s):
            for side in ["left", "right", "bottom", "top"]:
                assert e[side]
            for dim in ["col", "row"]:
                assert e[dim] is None
            for axis in "xy":
                assert e[axis] == axis
            assert e["ax"] == f.axes[i]

    @pytest.mark.parametrize("dim", ["col", "row"])
    def test_single_facet_dim(self, dim):
        """Faceting on one dimension gives one entry per level, in order."""
        key = "a"
        order = list("abc")
        spec = {"variables": {dim: key}, "structure": {dim: order}}
        s = Subplots({}, spec, {})
        s.init_figure(spec, {})

        assert len(s) == len(order)

        for i, e in enumerate(s):
            assert e[dim] == order[i]
            for axis in "xy":
                assert e[axis] == axis
            # Along the faceted dimension only the first/last entry touches
            # the edge; across it, every entry does.
            assert e["top"] == (dim == "col" or i == 0)
            assert e["bottom"] == (dim == "col" or i == len(order) - 1)
            assert e["left"] == (dim == "row" or i == 0)
            assert e["right"] == (dim == "row" or i == len(order) - 1)

    @pytest.mark.parametrize("dim", ["col", "row"])
    def test_single_facet_dim_wrapped(self, dim):
        """Wrapped facets report edges relative to the wrapped layout."""
        key = "b"
        order = list("abc")
        wrap = len(order) - 1
        spec = {"variables": {dim: key}, "structure": {dim: order}, "wrap": wrap}
        s = Subplots({}, spec, {})
        s.init_figure(spec, {})

        assert len(s) == len(order)

        # The side names are listed in the same order as the per-entry
        # expectations computed below; rows and columns swap roles with `dim`.
        sides = {
            "col": ["top", "bottom", "left", "right"],
            "row": ["left", "right", "top", "bottom"],
        }

        for i, e in enumerate(s):
            assert e[dim] == order[i]
            for axis in "xy":
                assert e[axis] == axis

            tests = (
                i < wrap,
                i >= wrap or i >= len(s) % wrap,
                i % wrap == 0,
                i % wrap == wrap - 1 or i + 1 == len(s),
            )

            for side, expected in zip(sides[dim], tests):
                assert e[side] == expected

    def test_both_facet_dims(self):
        """Faceting on both dimensions yields entries row by row."""
        col = "a"
        row = "b"
        col_order = list("ab")
        row_order = list("xyz")
        facet_spec = {
            "variables": {"col": col, "row": row},
            "structure": {"col": col_order, "row": row_order},
        }
        s = Subplots({}, facet_spec, {})
        s.init_figure(facet_spec, {})

        n_cols = len(col_order)
        n_rows = len(row_order)
        assert len(s) == n_cols * n_rows
        es = list(s)

        self._assert_outer_edges(es, n_cols)

        for e, (row_, col_) in zip(es, itertools.product(row_order, col_order)):
            assert e["col"] == col_
            assert e["row"] == row_

        for e in es:
            assert e["x"] == "x"
            assert e["y"] == "y"

    @pytest.mark.parametrize("var", ["x", "y"])
    def test_single_paired_var(self, var):
        """Pairing one variable lays entries out along a single dimension.

        The edge flags are checked for every entry rather than only the last
        one, so the first and interior subplots are covered as well.
        """
        other_var = {"x": "y", "y": "x"}[var]
        pairings = ["x", "y", "z"]
        pair_spec = {
            "variables": {f"{var}{i}": v for i, v in enumerate(pairings)},
            "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]},
        }

        s = Subplots({}, {}, pair_spec)
        s.init_figure(pair_spec)

        assert len(s) == len(pair_spec["structure"][var])

        # Sides are ordered as (leading, trailing, cross, cross): paired x
        # variables run left to right and paired y variables top to bottom.
        sides = {
            "x": ["left", "right", "top", "bottom"],
            "y": ["top", "bottom", "left", "right"],
        }

        for i, e in enumerate(s):
            assert e[var] == f"{var}{i}"
            assert e[other_var] == other_var
            assert e["col"] is e["row"] is None

            # Only the first/last entry touches the leading/trailing edge,
            # while every entry spans the whole of the other dimension.
            tests = i == 0, i == len(s) - 1, True, True
            for side, expected in zip(sides[var], tests):
                assert e[side] == expected

    @pytest.mark.parametrize("var", ["x", "y"])
    def test_single_paired_var_wrapped(self, var):
        """Wrapped pairing reports edges relative to the wrapped layout."""
        other_var = {"x": "y", "y": "x"}[var]
        pairings = ["x", "y", "z", "a", "b"]
        wrap = len(pairings) - 2
        pair_spec = {
            "variables": {f"{var}{i}": val for i, val in enumerate(pairings)},
            "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]},
            "wrap": wrap
        }
        s = Subplots({}, {}, pair_spec)
        s.init_figure(pair_spec)

        assert len(s) == len(pairings)

        sides = {
            "x": ["top", "bottom", "left", "right"],
            "y": ["left", "right", "top", "bottom"],
        }

        for i, e in enumerate(s):
            assert e[var] == f"{var}{i}"
            assert e[other_var] == other_var
            assert e["col"] is e["row"] is None

            tests = (
                i < wrap,
                i >= wrap or i >= len(s) % wrap,
                i % wrap == 0,
                i % wrap == wrap - 1 or i + 1 == len(s),
            )
            for side, expected in zip(sides[var], tests):
                assert e[side] == expected

    def test_both_paired_variables(self):
        """Crossed pairing yields entries row by row, x varying fastest."""
        x = ["x0", "x1"]
        y = ["y0", "y1", "y2"]
        pair_spec = {"structure": {"x": x, "y": y}}
        s = Subplots({}, {}, pair_spec)
        s.init_figure(pair_spec)

        n_cols = len(x)
        n_rows = len(y)
        assert len(s) == n_cols * n_rows
        es = list(s)

        self._assert_outer_edges(es, n_cols)

        for e in es:
            assert e["col"] is e["row"] is None

        for i in range(len(y)):
            for j in range(len(x)):
                e = es[i * len(x) + j]
                assert e["x"] == f"x{j}"
                assert e["y"] == f"y{i}"

    def test_both_paired_non_cross(self):
        """Non-crossed pairing matches the i-th x with the i-th y in one row."""
        pair_spec = {
            "structure": {"x": ["x0", "x1", "x2"], "y": ["y0", "y1", "y2"]},
            "cross": False
        }
        s = Subplots({}, {}, pair_spec)
        s.init_figure(pair_spec)

        for i, e in enumerate(s):
            assert e["x"] == f"x{i}"
            assert e["y"] == f"y{i}"
            assert e["col"] is e["row"] is None
            assert e["left"] == (i == 0)
            assert e["right"] == (i == (len(s) - 1))
            assert e["top"]
            assert e["bottom"]

    @pytest.mark.parametrize("dim,var", [("col", "y"), ("row", "x")])
    def test_one_facet_one_paired(self, dim, var):
        """Faceting one dimension while pairing the other fills a full grid."""
        other_var = {"x": "y", "y": "x"}[var]
        other_dim = {"col": "row", "row": "col"}[dim]
        order = list("abc")
        facet_spec = {"variables": {dim: "s"}, "structure": {dim: order}}

        pairings = ["x", "y", "t"]
        pair_spec = {
            "variables": {f"{var}{i}": val for i, val in enumerate(pairings)},
            "structure": {var: [f"{var}{i}" for i, _ in enumerate(pairings)]},
        }

        s = Subplots({}, facet_spec, pair_spec)
        s.init_figure(pair_spec)

        n_cols = len(order) if dim == "col" else len(pairings)
        n_rows = len(order) if dim == "row" else len(pairings)

        assert len(s) == len(order) * len(pairings)

        es = list(s)

        self._assert_outer_edges(es, n_cols)

        # Reorder the entries so the facet level always varies fastest; for
        # row facets that means walking the grid column by column.
        if dim == "row":
            es = np.reshape(es, (n_rows, n_cols)).T.ravel()

        for i, e in enumerate(es):
            # The facet level cycles with period len(order) and the paired
            # variable advances once per full cycle of facet levels.
            assert e[dim] == order[i % len(order)]
            assert e[other_dim] is None
            assert e[var] == f"{var}{i // len(order)}"
            assert e[other_var] == other_var


================================================
FILE: tests/_marks/__init__.py
================================================


================================================
FILE: tests/_marks/test_area.py
================================================

import matplotlib as mpl
from matplotlib.colors import to_rgba, to_rgba_array

from numpy.testing import assert_array_equal

from seaborn._core.plot import Plot
from seaborn._marks.area import Area, Band


class TestArea:
    """Tests for the polygon patches drawn by the Area mark."""

    def test_single_defaults(self):
        """One area is a polygon closed along y=0 with default theme styling."""
        x, y = [1, 2, 3], [1, 2, 1]
        plot = Plot(x=x, y=y).add(Area()).plot()
        patch = plot._figure.axes[0].patches[0]
        vertices = patch.get_path().vertices.T
        theme_colors = plot._theme["axes.prop_cycle"].by_key()["color"]
        base_color = theme_colors[0]

        # The outline runs along the baseline first, then back over the data.
        assert_array_equal(vertices[0], [1, 2, 3, 3, 2, 1, 1])
        assert_array_equal(vertices[1], [0, 0, 0, 1, 2, 1, 0])

        assert_array_equal(patch.get_facecolor(), to_rgba(base_color, .2))
        assert_array_equal(patch.get_edgecolor(), to_rgba(base_color, 1))

        expected_width = mpl.rcParams["patch.linewidth"] * 2
        assert_array_equal(patch.get_linewidth(), expected_width)

    def test_set_properties(self):
        """Properties set directly on the mark are applied to the patch."""
        x, y = [1, 2, 3], [1, 2, 1]
        mark = Area(
            color=".33",
            alpha=.3,
            edgecolor=".88",
            edgealpha=.8,
            edgewidth=2,
            edgestyle=(0, (2, 1)),
        )
        plot = Plot(x=x, y=y).add(mark).plot()
        patch = plot._figure.axes[0].patches[0]

        face = to_rgba(mark.color, mark.alpha)
        assert_array_equal(patch.get_facecolor(), face)

        edge = to_rgba(mark.edgecolor, mark.edgealpha)
        assert_array_equal(patch.get_edgecolor(), edge)

        # The drawn line width is twice the requested edge width.
        assert_array_equal(patch.get_linewidth(), mark.edgewidth * 2)

        _, (dash_on, dash_off) = mark.edgestyle
        expected_dashes = (
            mark.edgewidth * dash_on / 4,
            mark.edgewidth * dash_off / 4,
        )
        assert patch.get_linestyle() == (0, expected_dashes)

    def test_mapped_properties(self):
        """Mapping color and edgewidth draws one styled polygon per group."""
        x, y = [1, 2, 3, 2, 3, 4], [1, 2, 1, 1, 3, 2]
        groups = ["a", "a", "a", "b", "b", "b"]
        grays = [".2", ".8"]
        plot = (
            Plot(x=x, y=y, color=groups, edgewidth=groups)
            .scale(color=grays)
            .add(Area())
            .plot()
        )
        patches = plot._figure.axes[0].patches

        expected_x = [1, 2, 3, 3, 2, 1, 1], [2, 3, 4, 4, 3, 2, 2]
        expected_y = [0, 0, 0, 1, 2, 1, 0], [0, 0, 0, 2, 3, 1, 0]

        for i, patch in enumerate(patches):
            vertices = patch.get_path().vertices.T
            assert_array_equal(vertices[0], expected_x[i])
            assert_array_equal(vertices[1], expected_y[i])

        facecolors = [patch.get_facecolor() for patch in patches]
        assert_array_equal(facecolors, to_rgba_array(grays, .2))

        edgecolors = [patch.get_edgecolor() for patch in patches]
        assert_array_equal(edgecolors, to_rgba_array(grays, 1))

        first_width, second_width = (patch.get_linewidth() for patch in patches)
        assert first_width > second_width

    def test_unfilled(self):
        """With fill=False the face color is fully transparent."""
        x, y = [1, 2, 3], [1, 2, 1]
        gray = ".5"
        plot = Plot(x=x, y=y).add(Area(fill=False, color=gray)).plot()
        patch = plot._figure.axes[0].patches[0]
        assert patch.get_facecolor() == to_rgba(gray, 0)


class TestBand:
    """Tests for the polygon patches drawn by the Band mark."""

    def test_range(self):
        """Explicit ymin/ymax values define the lower and upper outline."""
        x, ymin, ymax = [1, 2, 4], [2, 1, 4], [3, 3, 5]
        plot = Plot(x=x, ymin=ymin, ymax=ymax).add(Band()).plot()
        patch = plot._figure.axes[0].patches[0]
        vertices = patch.get_path().vertices.T

        # The outline follows ymin left to right, then ymax back again.
        assert_array_equal(vertices[0], [1, 2, 4, 4, 2, 1, 1])
        assert_array_equal(vertices[1], [2, 1, 4, 5, 3, 3, 2])

    def test_auto_range(self):
        """With only y given, the band spans the min and max of y at each x."""
        x = [1, 1, 2, 2, 2]
        y = [1, 2, 3, 4, 5]
        plot = Plot(x=x, y=y).add(Band()).plot()
        patch = plot._figure.axes[0].patches[0]
        vertices = patch.get_path().vertices.T

        assert_array_equal(vertices[0], [1, 2, 2, 1, 1])
        assert_array_equal(vertices[1], [1, 3, 5, 2, 1])


================================================
FILE: tests/_marks/test_bar.py
================================================

import numpy as np
import pandas as pd
from matplotlib.colors import to_rgba, to_rgba_array

import pytest
from numpy.testing import assert_array_equal

from seaborn._core.plot import Plot
from seaborn._marks.bar import Bar, Bars


class TestBar:
    """Tests for the rectangle patches drawn by the Bar mark."""

    def plot_bars(self, variables, mark_kws, layer_kws):
        """Plot a single Bar layer and return its bars as a flat list."""
        plot = Plot(**variables).add(Bar(**mark_kws), **layer_kws).plot()
        bars = []
        for container in plot._figure.axes[0].containers:
            bars.extend(container)
        return bars

    def check_bar(self, bar, x, y, width, height):
        """Assert that the bar's position and size approximately match."""
        observed = bar.get_x(), bar.get_y(), bar.get_width(), bar.get_height()
        for actual, wanted in zip(observed, (x, y, width, height)):
            assert actual == pytest.approx(wanted)

    def test_categorical_positions_vertical(self):
        """Categorical x centers each vertical bar on its integer position."""
        x = ["a", "b"]
        y = [1, 2]
        w = .8
        bars = self.plot_bars({"x": x, "y": y}, {}, {})
        for pos, bar in enumerate(bars):
            self.check_bar(bar, pos - w / 2, 0, w, y[pos])

    def test_categorical_positions_horizontal(self):
        """Categorical y centers each horizontal bar on its integer position."""
        x = [1, 2]
        y = ["a", "b"]
        w = .8
        bars = self.plot_bars({"x": x, "y": y}, {}, {})
        for pos, bar in enumerate(bars):
            self.check_bar(bar, 0, pos - w / 2, x[pos], w)

    def test_numeric_positions_vertical(self):
        """Numeric x centers each vertical bar on its x value."""
        x = [1, 2]
        y = [3, 4]
        w = .8
        bars = self.plot_bars({"x": x, "y": y}, {}, {})
        for pos, bar in enumerate(bars):
            self.check_bar(bar, x[pos] - w / 2, 0, w, y[pos])

    def test_numeric_positions_horizontal(self):
        """With orient="h", each bar is centered on its y value."""
        x = [1, 2]
        y = [3, 4]
        w = .8
        bars = self.plot_bars({"x": x, "y": y}, {}, {"orient": "h"})
        for pos, bar in enumerate(bars):
            self.check_bar(bar, 0, y[pos] - w / 2, x[pos], w)

    def test_set_properties(self):
        """Properties set directly on the mark are applied to every bar."""
        x = ["a", "b", "c"]
        y = [1, 3, 2]

        mark = Bar(
            color=".8",
            alpha=.5,
            edgecolor=".3",
            edgealpha=.9,
            edgestyle=(2, 1),
            edgewidth=1.5,
        )

        plot = Plot(x, y).add(mark).plot()
        for patch in plot._figure.axes[0].patches:
            assert patch.get_facecolor() == to_rgba(mark.color, mark.alpha)
            assert patch.get_edgecolor() == to_rgba(mark.edgecolor, mark.edgealpha)
            # See comments in plotting method for why we need these adjustments
            assert patch.get_linewidth() == mark.edgewidth * 2
            dash_on, dash_off = mark.edgestyle
            assert patch.get_linestyle() == (0, (dash_on / 2, dash_off / 2))

    def test_mapped_properties(self):
        """Mapped color and edgewidth vary across bars; alpha stays fixed."""
        x = ["a", "b"]
        y = [1, 2]
        mark = Bar(alpha=.2)
        plot = Plot(x, y, color=x, edgewidth=y).add(mark).plot()
        patches = plot._figure.axes[0].patches
        theme_colors = plot._theme["axes.prop_cycle"].by_key()["color"]
        for pos, patch in enumerate(patches):
            assert patch.get_facecolor() == to_rgba(theme_colors[pos], mark.alpha)
            assert patch.get_edgecolor() == to_rgba(theme_colors[pos], 1)
        assert patches[0].get_linewidth() < patches[1].get_linewidth()

    def test_zero_height_skipped(self):
        """No patch is drawn for the bar whose value is zero."""
        plot = Plot(["a", "b", "c"], [1, 0, 2]).add(Bar()).plot()
        patches = plot._figure.axes[0].patches
        assert len(patches) == 2

    def test_artist_kws_clip(self):
        """Passing clip_on=False via artist keywords leaves the patch unclipped."""
        mark = Bar({"clip_on": False})
        plot = Plot(["a", "b"], [1, 2]).add(mark).plot()
        first_patch = plot._figure.axes[0].patches[0]
        assert first_patch.clipbox is None


class TestBars:
    """Tests for the Bars mark, inspecting the first collection on the axes."""

    @pytest.fixture
    def x(self):
        """Series of five numeric values named "x"."""
        values = [4, 5, 6, 7, 8]
        return pd.Series(values, name="x")

    @pytest.fixture
    def y(self):
        """Series of five numeric values named "y"."""
        values = [2, 8, 3, 5, 9]
        return pd.Series(values, name="y")

    @pytest.fixture
    def color(self):
        """Series of five string labels (three distinct) named "color"."""
        labels = ["a", "b", "c", "a", "c"]
        return pd.Series(labels, name="color")

    def test_positions(self, x, y):
        """Check the path vertices of vertical bars against the input data."""
        plotter = Plot(x, y).add(Bars()).plot()
        bar_paths = plotter._figure.axes[0].collections[0].get_paths()
        assert len(bar_paths) == len(x)
        for idx, bar_path in enumerate(bar_paths):
            corners = bar_path.vertices
            assert corners[0, 0] == pytest.approx(x[idx] - .5)
            assert corners[1, 0] == pytest.approx(x[idx] + .5)
            assert corners[0, 1] == 0
            assert corners[3, 1] == y[idx]

    def test_positions_horizontal(self, x, y):
        """Check the path vertices of bars drawn with ``orient="h"``."""
        plotter = Plot(x=y, y=x).add(Bars(), orient="h").plot()
        bar_paths = plotter._figure.axes[0].collections[0].get_paths()
        assert len(bar_paths) == len(x)
        for idx, bar_path in enumerate(bar_paths):
            corners = bar_path.vertices
            assert corners[0, 1] == pytest.approx(x[idx] - .5)
            assert corners[3, 1] == pytest.approx(x[idx] + .5)
            assert corners[0, 0] == 0
            assert corners[1, 0] == y[idx]

    def test_width(self, x, y):
        """Check that ``width=.4`` puts the bar edges .2 either side of x."""
        mark = Bars(width=.4)
        plotter = Plot(x, y).add(mark).plot()
        bar_paths = plotter._figure.axes[0].collections[0].get_paths()
        for idx, bar_path in enumerate(bar_paths):
            corners = bar_path.vertices
            assert corners[0, 0] == pytest.approx(x[idx] - .2)
            assert corners[1, 0] == pytest.approx(x[idx] + .2)

    def test_mapped_color_direct_alpha(self, x, y, color):
        """Check facecolors when color is mapped and alpha is set on the mark."""
        alpha = .5
        mark = Bars(alpha=alpha)
        plotter = Plot(x, y, color=color).add(mark).plot()
        bars = plotter._figure.axes[0].collections[0]
        actual = bars.get_facecolors()
        cycle = plotter._theme["axes.prop_cycle"].by_key()["color"]
        C0, C1, C2, *_ = cycle
        # Same level pattern as the ``color`` fixture: a, b, c, a, c
        expected = to_rgba_array([C0, C1, C2, C0, C2], alpha)
        assert_array_equal(actual, expected)

    def test_mapped_edgewidth(self, x, y):
        """Check that mapped edge widths are ordered like the mapped data."""
        plotter = Plot(x, y, edgewidth=y).add(Bars()).plot()
        bars = plotter._figure.axes[0].collections[0]
        widths = bars.get_linewidths()
        assert_array_equal(np.argsort(widths), np.argsort(y))

    def test_auto_edgewidth(self):
        """Check that 10 bars get wider edges than 1000 bars by default."""
        plotters = []
        for n_bars in (10, 1000):
            values = np.arange(n_bars)
            plotters.append(Plot(values, values).add(Bars()).plot())

        few, many = (
            p._figure.axes[0].collections[0].get_linewidths() for p in plotters
        )

        assert (few > many).all()

    def test_unfilled(self, x, y):
        """Check face and edge colors when ``fill=False`` and an edgecolor is set."""
        mark = Bars(fill=False, edgecolor="C4")
        plotter = Plot(x, y).add(mark).plot()
        axes = plotter._figure.axes[0]
        faces = axes.collections[0].get_facecolors()
        edges = axes.collections[0].get_edgecolors()
        cycle = plotter._theme["axes.prop_cycle"].by_key()["color"]
        n_bars = len(x)
        # Faces keep the default color but with zero alpha; edges are opaque
        assert_array_equal(faces, to_rgba_array([cycle[0]] * n_bars, 0))
        assert_array_equal(edges, to_rgba_array([cycle[4]] * n_bars, 1))

    def test_log_scale(self):
        """Check that neighboring bars share an edge when x is log-scaled."""
        values = [1, 10, 100, 1000]
        plotter = Plot(values, values).add(Bars()).scale(x="log").plot()
        axes = plotter._figure.axes[0]

        bar_paths = axes.collections[0].get_paths()
        for left, right in zip(bar_paths, bar_paths[1:]):
            assert left.vertices[1, 0] == pytest.approx(right.vertices[0, 0])


================================================
FILE: tests/_marks/test_base.py
================================================
from dataclasses import dataclass

import numpy as np
import pandas as pd
import matplotlib as mpl

import pytest
from numpy.testing import assert_array_equal

from seaborn._marks.base import Mark, Mappable, resolve_color


class TestMappable:
    """Tests for ``Mappable`` feature specs and how a ``Mark`` resolves them."""

    def mark(self, **features):
        """Build a mock mark, overriding its default features with ``features``."""

        @dataclass
        class MockMark(Mark):
            linewidth: float = Mappable(rc="lines.linewidth")
            pointsize: float = Mappable(4)
            color: str = Mappable("C0")
            fillcolor: str = Mappable(depend="color")
            alpha: float = Mappable(1)
            fillalpha: float = Mappable(depend="alpha")

        return MockMark(**features)

    def test_repr(self):
        """Check the angle-bracketed string form of each kind of Mappable."""
        assert str(Mappable(.5)) == "<0.5>"
        assert str(Mappable("CO")) == "<'CO'>"
        # The rc/depend/auto forms are tag-like strings; they must not be
        # compared against "" (three distinct specs cannot all render empty).
        assert str(Mappable(rc="lines.linewidth")) == "<rc:lines.linewidth>"
        assert str(Mappable(depend="color")) == "<depend:color>"
        assert str(Mappable(auto=True)) == "<auto>"

    def test_input_checks(self):
        """Check that an unknown rc parameter or dependency is rejected."""
        with pytest.raises(AssertionError):
            Mappable(rc="bogus.parameter")
        with pytest.raises(AssertionError):
            Mappable(depend="nonexistent_feature")

    def test_value(self):
        """Check resolution of a feature that was set to a plain value."""
        val = 3
        m = self.mark(linewidth=val)
        assert m._resolve({}, "linewidth") == val

        # With a frame, the scalar is expected once per row
        df = pd.DataFrame(index=pd.RangeIndex(10))
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

    def test_default(self):
        """Check resolution of a feature wrapped as a Mappable default."""
        val = 3
        m = self.mark(linewidth=Mappable(val))
        assert m._resolve({}, "linewidth") == val

        df = pd.DataFrame(index=pd.RangeIndex(10))
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

    def test_rcparam(self):
        """Check that an rc-backed feature resolves to the current rcParam."""
        param = "lines.linewidth"
        val = mpl.rcParams[param]

        m = self.mark(linewidth=Mappable(rc=param))
        assert m._resolve({}, "linewidth") == val

        df = pd.DataFrame(index=pd.RangeIndex(10))
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

    def test_depends(self):
        """Check that a dependent feature follows the feature it depends on."""
        val = 2
        df = pd.DataFrame(index=pd.RangeIndex(10))

        # Dependency target given as a Mappable default
        m = self.mark(pointsize=Mappable(val), linewidth=Mappable(depend="pointsize"))
        assert m._resolve({}, "linewidth") == val
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val))

        # Dependency target given as a plain value
        m = self.mark(pointsize=val * 2, linewidth=Mappable(depend="pointsize"))
        assert m._resolve({}, "linewidth") == val * 2
        assert_array_equal(m._resolve(df, "linewidth"), np.full(len(df), val * 2))

    def test_mapped(self):
        """Check that a scale function is applied when the data has the feature."""
        values = {"a": 1, "b": 2, "c": 3}

        def f(x):
            return np.array([values[x_i] for x_i in x])

        m = self.mark(linewidth=Mappable(2))
        scales = {"linewidth": f}

        assert m._resolve({"linewidth": "c"}, "linewidth", scales) == 3

        df = pd.DataFrame({"linewidth": ["a", "b", "c"]})
        expected = np.array([1, 2, 3], float)
        assert_array_equal(m._resolve(df, "linewidth", scales), expected)

    def test_color(self):
        """Check that resolve_color combines the mark's color and alpha."""
        c, a = "C1", .5
        m = self.mark(color=c, alpha=a)

        assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a)

        df = pd.DataFrame(index=pd.RangeIndex(10))
        cs = [c] * len(df)
        assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a))

    def test_color_mapped_alpha(self):
        """Check resolve_color with a fixed color and a scaled alpha."""
        c = "r"
        values = {"a": .2, "b": .5, "c": .8}

        m = self.mark(color=c, alpha=Mappable(1))
        scales = {"alpha": lambda s: np.array([values[s_i] for s_i in s])}

        assert resolve_color(m, {"alpha": "b"}, "", scales) == mpl.colors.to_rgba(c, .5)

        df = pd.DataFrame({"alpha": list(values.keys())})

        # Do this in two steps for mpl 3.2 compat
        expected = mpl.colors.to_rgba_array([c] * len(df))
        expected[:, 3] = list(values.values())

        assert_array_equal(resolve_color(m, df, "", scales), expected)

    def test_color_scaled_as_strings(self):
        """Check resolve_color when the color scale returns color strings."""
        colors = ["C1", "dodgerblue", "#445566"]
        m = self.mark()
        scales = {"color": lambda s: colors}

        actual = resolve_color(m, {"color": pd.Series(["a", "b", "c"])}, "", scales)
        expected = mpl.colors.to_rgba_array(colors)
        assert_array_equal(actual, expected)

    def test_fillcolor(self):
        """Check that the "fill" prefix selects fillcolor and fillalpha."""
        c, a = "green", .8
        fa = .2
        m = self.mark(
            color=c, alpha=a,
            fillcolor=Mappable(depend="color"), fillalpha=Mappable(fa),
        )

        assert resolve_color(m, {}) == mpl.colors.to_rgba(c, a)
        assert resolve_color(m, {}, "fill") == mpl.colors.to_rgba(c, fa)

        df = pd.DataFrame(index=pd.RangeIndex(10))
        cs = [c] * len(df)
        assert_array_equal(resolve_color(m, df), mpl.colors.to_rgba_array(cs, a))
        assert_array_equal(
            resolve_color(m, df, "fill"), mpl.colors.to_rgba_array(cs, fa)
        )


================================================
FILE: tests/_marks/test_dot.py
================================================
from matplotlib.colors import to_rgba, to_rgba_array

import pytest
from numpy.testing import assert_array_equal

from seaborn.palettes import color_palette
from seaborn._core.plot import Plot
from seaborn._marks.dot import Dot, Dots


@pytest.fixture(autouse=True)
def default_palette():
    """Run each test in this module inside the "deep" color palette context."""
    palette_context = color_palette("deep")
    with palette_context:
        yield


class DotBase:
    """Assertion helpers shared by the dot mark test classes."""

    def check_offsets(self, points, x, y):
        """Assert that the offsets of ``points`` match the given x and y values."""
        coords = points.get_offsets().T
        actual_x, actual_y = coords[0], coords[1]
        assert_array_equal(actual_x, x)
        assert_array_equal(actual_y, y)

    def check_colors(self, part, points, colors, alpha=None):
        """Assert that the ``get_{part}colors`` array matches colors and alpha."""
        expected = to_rgba_array(colors, alpha)

        method_name = f"get_{part}colors"
        actual = getattr(points, method_name)()
        assert_array_equal(actual, expected)


class TestDot(DotBase):
    """Tests for the Dot mark, inspecting the single collection on the axes."""

    def test_simple(self):
        """Check offsets and default face/edge colors for plain x/y data."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        plotter = Plot(x=xs, y=ys).add(Dot()).plot()
        points, = plotter._figure.axes[0].collections
        C0, *_ = plotter._theme["axes.prop_cycle"].by_key()["color"]
        self.check_offsets(points, xs, ys)
        for part in ("face", "edge"):
            self.check_colors(part, points, [C0] * 3, 1)

    def test_filled_unfilled_mix(self):
        """Check colors and line widths when "o" and "x" markers are mixed."""
        xs = [1, 2]
        ys = [4, 5]
        levels = ["a", "b"]
        shapes = ["o", "x"]

        mark = Dot(edgecolor="w", stroke=2, edgewidth=1)
        plotter = (
            Plot(x=xs, y=ys)
            .add(mark, marker=levels)
            .scale(marker=shapes)
            .plot()
        )
        points, = plotter._figure.axes[0].collections
        C0, *_ = plotter._theme["axes.prop_cycle"].by_key()["color"]
        self.check_offsets(points, xs, ys)
        # Second point ("x") is expected with a transparent face and a C0 edge
        self.check_colors("face", points, [C0, to_rgba(C0, 0)], None)
        self.check_colors("edge", points, ["w", C0], 1)

        expected_widths = [mark.edgewidth, mark.stroke]
        assert_array_equal(points.get_linewidths(), expected_widths)

    def test_missing_coordinate_data(self):
        """Check that a point with a NaN coordinate is left out of the offsets."""
        xs = [1, float("nan"), 3]
        ys = [5, 3, 4]

        plotter = Plot(x=xs, y=ys).add(Dot()).plot()
        points, = plotter._figure.axes[0].collections
        self.check_offsets(points, [1, 3], [5, 4])

    @pytest.mark.parametrize("prop", ["color", "fill", "marker", "pointsize"])
    def test_missing_semantic_data(self, prop):
        """Check that a point with a NaN semantic value is left out."""
        xs = [1, 2, 3]
        ys = [5, 3, 4]
        semantic = ["a", float("nan"), "b"]

        mapping = {prop: semantic}
        plotter = Plot(x=xs, y=ys, **mapping).add(Dot()).plot()
        points, = plotter._figure.axes[0].collections
        self.check_offsets(points, [1, 3], [5, 4])


class TestDots(DotBase):
    """Tests for the Dots mark, inspecting the single collection on the axes."""

    def test_simple(self):
        """Check offsets and the default face (alpha .2) and edge colors."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        plotter = Plot(x=xs, y=ys).add(Dots()).plot()
        points, = plotter._figure.axes[0].collections
        C0, *_ = plotter._theme["axes.prop_cycle"].by_key()["color"]
        self.check_offsets(points, xs, ys)
        self.check_colors("face", points, [C0] * 3, .2)
        self.check_colors("edge", points, [C0] * 3, 1)

    def test_set_color(self):
        """Check face and edge colors when color is set on the mark."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        mark = Dots(color=".25")
        plotter = Plot(x=xs, y=ys).add(mark).plot()
        points, = plotter._figure.axes[0].collections
        self.check_offsets(points, xs, ys)
        expected = [mark.color] * 3
        self.check_colors("face", points, expected, .2)
        self.check_colors("edge", points, expected, 1)

    def test_map_color(self):
        """Check face and edge colors when color is mapped from data."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        levels = ["a", "b", "a"]
        plotter = Plot(x=xs, y=ys, color=levels).add(Dots()).plot()
        points, = plotter._figure.axes[0].collections
        C0, C1, *_ = plotter._theme["axes.prop_cycle"].by_key()["color"]
        self.check_offsets(points, xs, ys)
        expected = [C0, C1, C0]
        self.check_colors("face", points, expected, .2)
        self.check_colors("edge", points, expected, 1)

    def test_fill(self):
        """Check that ``fill=False`` gives zero-alpha faces and opaque edges."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        levels = ["a", "b", "a"]
        mark = Dots(fill=False)
        plotter = Plot(x=xs, y=ys, color=levels).add(mark).plot()
        points, = plotter._figure.axes[0].collections
        C0, C1, *_ = plotter._theme["axes.prop_cycle"].by_key()["color"]
        self.check_offsets(points, xs, ys)
        expected = [C0, C1, C0]
        self.check_colors("face", points, expected, 0)
        self.check_colors("edge", points, expected, 1)

    def test_pointsize(self):
        """Check that collection sizes are the squared ``pointsize``."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        size = 3
        plotter = Plot(x=xs, y=ys).add(Dots(pointsize=size)).plot()
        points, = plotter._figure.axes[0].collections
        self.check_offsets(points, xs, ys)
        assert_array_equal(points.get_sizes(), [size ** 2] * 3)

    def test_stroke(self):
        """Check that ``stroke`` sets the collection line widths."""
        xs = [1, 2, 3]
        ys = [4, 5, 2]
        stroke = 3
        plotter = Plot(x=xs, y=ys).add(Dots(stroke=stroke)).plot()
        points, = plotter._figure.axes[0].collections
        self.check_offsets(points, xs, ys)
        assert_array_equal(points.get_linewidths(), [stroke] * 3)

    def test_filled_unfilled_mix(self):
        """Check colors and line widths when "o" and "x" markers are mixed."""
        xs = [1, 2]
        ys = [4, 5]
        levels = ["a", "b"]
        shapes = ["o", "x"]

        mark = Dots(stroke=2)
        plotter = (
            Plot(x=xs, y=ys)
            .add(mark, marker=levels)
            .scale(marker=shapes)
            .plot()
        )
        points, = plotter._figure.axes[0].collections
        C0, C1, *_ = plotter._theme["axes.prop_cycle"].by_key()["color"]
        self.check_offsets(points, xs, ys)
        faces = [to_rgba(C0, .2), to_rgba(C0, 0)]
        self.check_colors("face", points, faces, None)
        self.check_colors("edge", points, [C0, C0], 1)
        assert_array_equal(points.get_linewidths(), [mark.stroke] * 2)


================================================
FILE: tests/_marks/test_line.py
================================================

import numpy as np
import matplotlib as mpl
from matplotlib.colors import same_color, to_rgba

from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._core.plot import Plot
from seaborn._core.moves import Dodge
from seaborn._marks.line import Dash, Line, Path, Lines, Paths, Range


class TestPath:
    """Tests for the Path mark, inspecting the lines added to the axes."""

    def test_xy_data(self):
        """Check per-group line data, including the NaN kept in the first group."""
        xs = [1, 5, 3, np.nan, 2]
        ys = [1, 4, 2, 5, 3]
        groups = [1, 2, 1, 1, 2]
        plotter = Plot(x=xs, y=ys, group=groups).add(Path()).plot()
        first, second = plotter._figure.axes[0].get_lines()

        assert_array_equal(first.get_xdata(), [1, 3, np.nan])
        assert_array_equal(first.get_ydata(), [1, 2, np.nan])
        assert_array_equal(second.get_xdata(), [5, 2])
        assert_array_equal(second.get_ydata(), [4, 3])

    def test_shared_colors_direct(self):
        """Check that a directly-set color is used for line and marker parts."""
        data = [1, 2, 3]
        shade = ".44"
        mark = Path(color=shade)
        plotter = Plot(x=data, y=data).add(mark).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert same_color(line.get_color(), shade)
        assert same_color(line.get_markeredgecolor(), shade)
        assert same_color(line.get_markerfacecolor(), shade)

    def test_separate_colors_direct(self):
        """Check distinct directly-set line, marker edge, and marker fill colors."""
        xs = [1, 2, 3]
        ys = [1, 2, 3]
        mark = Path(color=".22", edgecolor=".55", fillcolor=".77")
        plotter = Plot(x=xs, y=ys).add(mark).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert same_color(line.get_color(), mark.color)
        assert same_color(line.get_markeredgecolor(), mark.edgecolor)
        assert same_color(line.get_markerfacecolor(), mark.fillcolor)

    def test_shared_colors_mapped(self):
        """Check that a mapped color is used for line and marker parts."""
        data = [1, 2, 3, 4]
        levels = ["a", "a", "b", "b"]
        mark = Path()
        plotter = Plot(x=data, y=data, color=levels).add(mark).plot()
        axes = plotter._figure.axes[0]
        cycle = plotter._theme["axes.prop_cycle"].by_key()["color"]
        for idx, line in enumerate(axes.get_lines()):
            expected = cycle[idx]
            assert same_color(line.get_color(), expected)
            assert same_color(line.get_markeredgecolor(), expected)
            assert same_color(line.get_markerfacecolor(), expected)

    def test_separate_colors_mapped(self):
        """Check line colors when color and fillcolor map different variables."""
        data = [1, 2, 3, 4]
        color_levels = ["a", "a", "b", "b"]
        fill_levels = ["x", "y", "x", "y"]
        mark = Path()
        plotter = (
            Plot(x=data, y=data, color=color_levels, fillcolor=fill_levels)
            .add(mark)
            .plot()
        )
        axes = plotter._figure.axes[0]
        cycle = plotter._theme["axes.prop_cycle"].by_key()["color"]
        for idx, line in enumerate(axes.get_lines()):
            # Lines come out ordered by color level first, fill level second
            assert same_color(line.get_color(), cycle[idx // 2])
            assert same_color(line.get_markeredgecolor(), cycle[idx // 2])
            assert same_color(line.get_markerfacecolor(), cycle[idx % 2])

    def test_color_with_alpha(self):
        """Check that RGBA tuples given as colors are preserved."""
        data = [1, 2, 3]
        mark = Path(color=(.4, .9, .2, .5), fillcolor=(.2, .2, .3, .9))
        plotter = Plot(x=data, y=data).add(mark).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert same_color(line.get_color(), mark.color)
        assert same_color(line.get_markeredgecolor(), mark.color)
        assert same_color(line.get_markerfacecolor(), mark.fillcolor)

    def test_color_and_alpha(self):
        """Check that a separate alpha is applied to RGB line and fill colors."""
        data = [1, 2, 3]
        mark = Path(color=(.4, .9, .2), fillcolor=(.2, .2, .3), alpha=.5)
        plotter = Plot(x=data, y=data).add(mark).plot()
        line, = plotter._figure.axes[0].get_lines()
        stroke_rgba = to_rgba(mark.color, mark.alpha)
        fill_rgba = to_rgba(mark.fillcolor, mark.alpha)
        assert same_color(line.get_color(), stroke_rgba)
        assert same_color(line.get_markeredgecolor(), stroke_rgba)
        assert same_color(line.get_markerfacecolor(), fill_rgba)

    def test_other_props_direct(self):
        """Check that directly-set non-color properties reach the line."""
        data = [1, 2, 3]
        mark = Path(
            marker="s", linestyle="--", linewidth=3, pointsize=10, edgewidth=1,
        )
        plotter = Plot(x=data, y=data).add(mark).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert line.get_marker() == mark.marker
        assert line.get_linestyle() == mark.linestyle
        assert line.get_linewidth() == mark.linewidth
        assert line.get_markersize() == mark.pointsize
        assert line.get_markeredgewidth() == mark.edgewidth

    def test_other_props_mapped(self):
        """Check that mapped marker and pointsize differ between groups."""
        data = [1, 2, 3, 4]
        levels = ["a", "a", "b", "b"]
        mark = Path()
        plotter = (
            Plot(x=data, y=data, marker=levels, linestyle=levels, pointsize=levels)
            .add(mark)
            .plot()
        )
        first, second = plotter._figure.axes[0].get_lines()
        assert first.get_marker() != second.get_marker()
        # Matplotlib bug in storing linestyle from dash pattern
        # assert line1.get_linestyle() != line2.get_linestyle()
        assert first.get_markersize() != second.get_markersize()

    def test_capstyle(self):
        """Check cap styles taken from the theme and from the artist kws."""
        data = [1, 2]
        rc = {"lines.solid_capstyle": "projecting", "lines.dash_capstyle": "round"}

        # Default (solid) line: the dash capstyle is expected to be "projecting"
        plotter = Plot(data, data).add(Path()).theme(rc).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert line.get_dash_capstyle() == "projecting"

        # Dashed line: the dash capstyle comes from the theme
        plotter = Plot(data, data).add(Path(linestyle="--")).theme(rc).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert line.get_dash_capstyle() == "round"

        # An explicit artist keyword takes precedence over the theme
        mark = Path({"solid_capstyle": "butt"})
        plotter = Plot(data, data).add(mark).theme(rc).plot()
        line, = plotter._figure.axes[0].get_lines()
        assert line.get_solid_capstyle() == "butt"


class TestLine:
    """Tests for the Line mark."""

    # Most behaviors shared with Path and covered by above tests

    def test_xy_data(self):
        """Check per-group line data; expected values are x-sorted without NaN."""
        xs = [1, 5, 3, np.nan, 2]
        ys = [1, 4, 2, 5, 3]
        groups = [1, 2, 1, 1, 2]
        plotter = Plot(x=xs, y=ys, group=groups).add(Line()).plot()
        first, second = plotter._figure.axes[0].get_lines()

        assert_array_equal(first.get_xdata(), [1, 3])
        assert_array_equal(first.get_ydata(), [1, 2])
        assert_array_equal(second.get_xdata(), [2, 5])
        assert_array_equal(second.get_ydata(), [3, 4])


class TestPaths:
    """Tests for the Paths mark, inspecting the single collection on the axes."""

    def test_xy_data(self):
        """Check per-group path vertices, including the NaN in the first group."""
        xs = [1, 5, 3, np.nan, 2]
        ys = [1, 4, 2, 5, 3]
        groups = [1, 2, 1, 1, 2]
        plotter = Plot(x=xs, y=ys, group=groups).add(Paths()).plot()
        lines, = plotter._figure.axes[0].collections

        first = lines.get_paths()[0].vertices.T
        assert_array_equal(first[0], [1, 3, np.nan])
        assert_array_equal(first[1], [1, 2, np.nan])

        second = lines.get_paths()[1].vertices.T
        assert_array_equal(second[0], [5, 2])
        assert_array_equal(second[1], [4, 3])

    def test_set_properties(self):
        """Check that directly-set color, linewidth, and dashes are applied."""
        data = [1, 2, 3]
        mark = Paths(color=".737", linewidth=1, linestyle=(3, 1))
        plotter = Plot(x=data, y=data).add(mark).plot()
        lines, = plotter._figure.axes[0].collections

        assert same_color(lines.get_color().squeeze(), mark.color)
        assert lines.get_linewidth().item() == mark.linewidth
        expected_dashes = (0, list(mark.linestyle))
        assert lines.get_dashes()[0] == expected_dashes

    def test_mapped_properties(self):
        """Check that mapped color, linewidth, and linestyle differ by group."""
        data = [1, 2, 3, 4]
        levels = ["a", "a", "b", "b"]
        plotter = (
            Plot(x=data, y=data, color=levels, linewidth=levels, linestyle=levels)
            .add(Paths())
            .plot()
        )
        lines, = plotter._figure.axes[0].collections

        same_colors = np.array_equal(lines.get_colors()[0], lines.get_colors()[1])
        assert not same_colors
        assert lines.get_linewidths()[0] != lines.get_linewidth()[1]
        assert lines.get_linestyle()[0] != lines.get_linestyle()[1]

    def test_color_with_alpha(self):
        """Check that an RGBA tuple given as the color is preserved."""
        data = [1, 2, 3]
        mark = Paths(color=(.2, .6, .9, .5))
        plotter = Plot(x=data, y=data).add(mark).plot()
        lines, = plotter._figure.axes[0].collections
        assert same_color(lines.get_colors().squeeze(), mark.color)

    def test_color_and_alpha(self):
        """Check that a separate alpha is applied to an RGB color."""
        data = [1, 2, 3]
        mark = Paths(color=(.2, .6, .9), alpha=.5)
        plotter = Plot(x=data, y=data).add(mark).plot()
        lines, = plotter._figure.axes[0].collections
        expected = to_rgba(mark.color, mark.alpha)
        assert same_color(lines.get_colors().squeeze(), expected)

    def test_capstyle(self):
        """Check the collection capstyle from rcParams and from artist kws."""
        data = [1, 2]
        rc = {"lines.solid_capstyle": "projecting"}

        with mpl.rc_context(rc):
            # Default (solid) paths
            plotter = Plot(data, data).add(Paths()).plot()
            lines = plotter._figure.axes[0].collections[0]
            assert lines.get_capstyle() == "projecting"

            # Dashed paths are expected to use the same capstyle
            plotter = Plot(data, data).add(Paths(linestyle="--")).plot()
            lines = plotter._figure.axes[0].collections[0]
            assert lines.get_capstyle() == "projecting"

            # An explicit artist keyword takes precedence over rcParams
            mark = Paths({"capstyle": "butt"})
            plotter = Plot(data, data).add(mark).plot()
            lines = plotter._figure.axes[0].collections[0]
            assert lines.get_capstyle() == "butt"


class TestLines:
    """Tests for the Lines mark, inspecting the single collection on the axes."""

    def test_xy_data(self):
        """Check per-group vertices; expected values are x-sorted without NaN."""
        xs = [1, 5, 3, np.nan, 2]
        ys = [1, 4, 2, 5, 3]
        groups = [1, 2, 1, 1, 2]
        plotter = Plot(x=xs, y=ys, group=groups).add(Lines()).plot()
        lines, = plotter._figure.axes[0].collections

        first = lines.get_paths()[0].vertices.T
        assert_array_equal(first[0], [1, 3])
        assert_array_equal(first[1], [1, 2])

        second = lines.get_paths()[1].vertices.T
        assert_array_equal(second[0], [2, 5])
        assert_array_equal(second[1], [3, 4])

    def test_single_orient_value(self):
        """Check vertices when every point has the same x value."""
        xs = [1, 1, 1]
        ys = [1, 2, 3]
        plotter = Plot(xs, ys).add(Lines()).plot()
        lines, = plotter._figure.axes[0].collections
        coords = lines.get_paths()[0].vertices.T
        assert_array_equal(coords[0], xs)
        assert_array_equal(coords[1], ys)


class TestRange:
    """Tests for the Range mark, inspecting the single collection on the axes."""

    def test_xy_data(self):
        """Check that each segment spans ymin to ymax at its x position."""
        xs = [1, 2]
        lows = [1, 4]
        highs = [2, 3]

        plotter = Plot(x=xs, ymin=lows, ymax=highs).add(Range()).plot()
        lines, = plotter._figure.axes[0].collections

        for idx, path in enumerate(lines.get_paths()):
            coords = path.vertices.T
            assert_array_equal(coords[0], [xs[idx], xs[idx]])
            assert_array_equal(coords[1], [lows[idx], highs[idx]])

    def test_auto_range(self):
        """Check that x/y data alone yields one min-to-max segment per x."""
        xs = [1, 1, 2, 2, 2]
        ys = [1, 2, 3, 4, 5]

        plotter = Plot(x=xs, y=ys).add(Range()).plot()
        lines, = plotter._figure.axes[0].collections
        segments = lines.get_paths()
        assert_array_equal(segments[0].vertices, [(1, 1), (1, 2)])
        assert_array_equal(segments[1].vertices, [(2, 3), (2, 5)])

    def test_mapped_color(self):
        """Check segment vertices and colors when color is mapped from data."""
        xs = [1, 2, 1, 2]
        lows = [1, 4, 3, 2]
        highs = [2, 3, 1, 4]
        levels = ["a", "a", "b", "b"]

        plotter = (
            Plot(x=xs, ymin=lows, ymax=highs, color=levels)
            .add(Range())
            .plot()
        )
        lines, = plotter._figure.axes[0].collections
        cycle = plotter._theme["axes.prop_cycle"].by_key()["color"]

        for idx, path in enumerate(lines.get_paths()):
            coords = path.vertices.T
            assert_array_equal(coords[0], [xs[idx], xs[idx]])
            assert_array_equal(coords[1], [lows[idx], highs[idx]])
            # Two consecutive segments belong to each color level
            assert same_color(lines.get_colors()[idx], cycle[idx // 2])

    def test_direct_properties(self):
        """Check that directly-set color and linewidth apply to every segment."""
        xs = [1, 2]
        lows = [1, 4]
        highs = [2, 3]

        mark = Range(color=".654", linewidth=4)
        plotter = Plot(x=xs, ymin=lows, ymax=highs).add(mark).plot()
        lines, = plotter._figure.axes[0].collections

        n_segments = len(lines.get_paths())
        for idx in range(n_segments):
            assert same_color(lines.get_colors()[idx], mark.color)
            assert lines.get_linewidths()[idx] == mark.linewidth


class TestDash:
    """Tests for the Dash mark, inspecting the single collection on the axes."""

    def test_xy_data(self):
        """Check that each dash spans x -/+ .4 at its y value by default."""
        xs = [0, 0, 1, 2]
        ys = [1, 2, 3, 4]

        plotter = Plot(x=xs, y=ys).add(Dash()).plot()
        lines, = plotter._figure.axes[0].collections

        for idx, path in enumerate(lines.get_paths()):
            coords = path.vertices.T
            assert_array_almost_equal(coords[0], [xs[idx] - .4, xs[idx] + .4])
            assert_array_equal(coords[1], [ys[idx], ys[idx]])

    def test_xy_data_grouped(self):
        """Check dash vertices when a mapped color reorders the paths."""
        xs = [0, 0, 1, 2]
        ys = [1, 2, 3, 4]
        levels = ["a", "b", "a", "b"]

        plotter = Plot(x=xs, y=ys, color=levels).add(Dash()).plot()
        lines, = plotter._figure.axes[0].collections

        # Expected input row for each path: "a" rows first, then "b" rows
        row_order = [0, 2, 1, 3]
        for row, path in zip(row_order, lines.get_paths()):
            coords = path.vertices.T
            assert_array_almost_equal(coords[0], [xs[row] - .4, xs[row] + .4])
            assert_array_equal(coords[1], [ys[row], ys[row]])

    def test_set_properties(self):
        """Check that directly-set color and linewidth apply to every dash."""
        xs = [0, 0, 1, 2]
        ys = [1, 2, 3, 4]

        mark = Dash(color=".8", linewidth=4)
        plotter = Plot(x=xs, y=ys).add(mark).plot()
        lines, = plotter._figure.axes[0].collections

        for dash_color in lines.get_color():
            assert same_color(dash_color, mark.color)
        for dash_width in lines.get_linewidth():
            assert dash_width == mark.linewidth

    def test_mapped_properties(self):
        """Check mapped colors against the theme cycle and linewidth ordering."""
        xs = [0, 1]
        ys = [1, 2]
        levels = ["a", "b"]
        widths = [1, 2]

        plotter = (
            Plot(x=xs, y=ys, color=levels, linewidth=widths)
            .add(Dash())
            .plot()
        )
        lines, = plotter._figure.axes[0].collections
        cycle = plotter._theme["axes.prop_cycle"].by_key()["color"]

        for expected, actual in zip(cycle, lines.get_color()):
            assert same_color(expected, actual)

        drawn_widths = lines.get_linewidths()
        assert drawn_widths[1] > drawn_widths[0]

    def test_width(self):
        """Check that ``width=.4`` makes each dash span x -/+ .2."""
        xs = [0, 0, 1, 2]
        ys = [1, 2, 3, 4]

        mark = Dash(width=.4)
        plotter = Plot(x=xs, y=ys).add(mark).plot()
        lines, = plotter._figure.axes[0].collections

        for idx, path in enumerate(lines.get_paths()):
            coords = path.vertices.T
            assert_array_almost_equal(coords[0], [xs[idx] - .2, xs[idx] + .2])
            assert_array_equal(coords[1], [ys[idx], ys[idx]])

    def test_dodge(self):
        """Check dash vertices for two groups when a Dodge move is applied."""
        xs = [0, 1]
        ys = [1, 2]
        levels = ["a", "b"]

        plotter = Plot(x=xs, y=ys, group=levels).add(Dash(), Dodge()).plot()
        lines, = plotter._figure.axes[0].collections

        segments = lines.get_paths()

        first = segments[0].vertices.T
        assert_array_almost_equal(first[0], [-.4, 0])
        assert_array_equal(first[1], [ys[0], ys[0]])

        second = segments[1].vertices.T
        assert_array_almost_equal(second[0], [1, 1.4])
        assert_array_equal(second[1], [ys[1], ys[1]])


================================================
FILE: tests/_marks/test_text.py
================================================

import numpy as np
from matplotlib.colors import to_rgba
from matplotlib.text import Text as MPLText

from numpy.testing import assert_array_almost_equal

from seaborn._core.plot import Plot
from seaborn._marks.text import Text


class TestText:
    """Tests for the Text mark, inspecting the text artists on the axes."""

    def get_texts(self, ax):
        """Return the text artists of ``ax`` as a list."""
        if not ax.texts:
            # Compatibility with matplotlib < 3.5 (I think)
            return [a for a in ax.artists if isinstance(a, MPLText)]
        return list(ax.texts)

    def test_simple(self):
        """Check position, string, and default alignment of each text."""
        coords = [1, 2, 3]
        labels = list("abc")

        plotter = Plot(coords, coords, text=labels).add(Text()).plot()
        axes = plotter._figure.axes[0]
        for idx, artist in enumerate(self.get_texts(axes)):
            actual_x, actual_y = artist.get_position()
            assert actual_x == coords[idx]
            assert actual_y == coords[idx]
            assert artist.get_text() == labels[idx]
            assert artist.get_horizontalalignment() == "center"
            assert artist.get_verticalalignment() == "center_baseline"

    def test_set_properties(self):
        """Check that directly-set properties reach each text artist."""
        coords = [1, 2, 3]
        labels = list("abc")

        mark = Text(color="red", alpha=.6, fontsize=6, valign="bottom")
        plotter = Plot(coords, coords, text=labels).add(mark).plot()
        axes = plotter._figure.axes[0]
        for idx, artist in enumerate(self.get_texts(axes)):
            assert artist.get_text() == labels[idx]
            assert artist.get_color() == to_rgba(mark.color, mark.alpha)
            assert artist.get_fontsize() == mark.fontsize
            assert artist.get_verticalalignment() == mark.valign

    def test_mapped_properties(self):
        """Check that mapped color and fontsize vary with the mapped data."""
        coords = [1, 2, 3]
        labels = list("abc")
        color_levels = list("aab")
        sizes = [1, 2, 4]

        plotter = (
            Plot(coords, coords, color=color_levels, fontsize=sizes, text=labels)
            .add(Text())
            .plot()
        )
        axes = plotter._figure.axes[0]
        artists = self.get_texts(axes)
        assert artists[0].get_color() == artists[1].get_color()
        assert artists[0].get_color() != artists[2].get_color()
        assert (
            artists[0].get_fontsize()
            < artists[1].get_fontsize()
            < artists[2].get_fontsize()
        )

    def test_mapped_alignment(self):
        """Check the alignments produced when halign and valign are mapped."""
        values = [1, 2]
        plotter = (
            Plot(x=values, y=values, halign=values, valign=values, text=values)
            .add(Text())
            .plot()
        )
        axes = plotter._figure.axes[0]
        first, second = self.get_texts(axes)
        assert first.get_horizontalalignment() == "left"
        assert second.get_horizontalalignment() == "right"
        assert first.get_verticalalignment() == "top"
        assert second.get_verticalalignment() == "bottom"

    def test_identity_fontsize(self):
        """Check that ``scale(fontsize=None)`` passes font sizes through as-is."""
        coords = [1, 2, 3]
        labels = list("abc")
        sizes = [5, 8, 12]
        plotter = (
            Plot(coords, coords, text=labels, fontsize=sizes)
            .add(Text())
            .scale(fontsize=None)
            .plot()
        )
        axes = plotter._figure.axes[0]
        for idx, artist in enumerate(self.get_texts(axes)):
            assert artist.get_fontsize() == sizes[idx]

    def test_offset_centered(self):
        """Check that default texts use the plain data transform."""
        coords = [1, 2, 3]
        labels = list("abc")
        plotter = Plot(coords, coords, text=labels).add(Text()).plot()
        axes = plotter._figure.axes[0]
        data_matrix = axes.transData.get_matrix()
        for artist in self.get_texts(axes):
            actual = artist.get_transform().get_matrix()
            assert_array_almost_equal(actual, data_matrix)

    def test_offset_valign(self):
        """Check the vertical shift applied to the transform for valign="bottom"."""
        coords = [1, 2, 3]
        labels = list("abc")
        mark = Text(valign="bottom", fontsize=5, offset=.1)
        plotter = Plot(coords, coords, text=labels).add(mark).plot()
        axes = plotter._figure.axes[0]
        # Only the y translation entry is expected to be non-zero
        expected_shift = np.zeros((3, 3))
        expected_shift[1, -1] = mark.offset * axes.figure.dpi / 72
        data_matrix = axes.transData.get_matrix()
        for artist in self.get_texts(axes):
            actual_shift = artist.get_transform().get_matrix() - data_matrix
            assert_array_almost_equal(actual_shift, expected_shift)

    def test_offset_halign(self):
        """Check the horizontal shift applied to the transform for halign="right"."""
        coords = [1, 2, 3]
        labels = list("abc")
        mark = Text(halign="right", fontsize=10, offset=.5)
        plotter = Plot(coords, coords, text=labels).add(mark).plot()
        axes = plotter._figure.axes[0]
        # Only the x translation entry is expected to be non-zero (negative)
        expected_shift = np.zeros((3, 3))
        expected_shift[0, -1] = -mark.offset * axes.figure.dpi / 72
        data_matrix = axes.transData.get_matrix()
        for artist in self.get_texts(axes):
            actual_shift = artist.get_transform().get_matrix() - data_matrix
            assert_array_almost_equal(actual_shift, expected_shift)


================================================
FILE: tests/_stats/__init__.py
================================================


================================================
FILE: tests/_stats/test_aggregation.py
================================================

import numpy as np
import pandas as pd

import pytest
from pandas.testing import assert_frame_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.aggregation import Agg, Est


class AggregationFixtures:
    """Shared data fixture and GroupBy helper for the aggregation stat tests."""

    @pytest.fixture
    def df(self, rng):
        """Return a 30-row frame with numeric x/y and string color/group columns."""
        n = 30
        columns = {
            "x": rng.uniform(0, 7, n).round(),
            "y": rng.normal(size=n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        }
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Build a GroupBy over every column of `df` except the non-orient axis."""
        flipped = {"x": "y", "y": "x"}
        value_var = flipped[orient]
        grouping = [name for name in df if name != value_var]
        return GroupBy(grouping)


class TestAgg(AggregationFixtures):
    """Tests for the Agg stat."""

    def test_default(self, df):
        """Default Agg output should equal the pandas per-x mean of y."""
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Agg()(data, groupby, orient, {})

        expected = data.groupby("x", as_index=False)["y"].mean()
        assert_frame_equal(result, expected)

    def test_default_multi(self, df):
        """Default Agg with several grouping variables matches a pandas groupby."""
        orient = "x"
        groupby = self.get_groupby(df, orient)
        result = Agg()(df, groupby, orient, {})

        keys = ["x", "color", "group"]
        levels = [
            sorted(df["x"].unique()),
            df["color"].unique(),
            df["group"].unique(),
        ]
        full_index = pd.MultiIndex.from_product(levels, names=["x", "color", "group"])
        group_means = df.groupby(keys).agg("mean")
        # Put the rows in product order, drop absent combinations, and restore
        # the original column order before comparing.
        expected = (
            group_means
            .reindex(index=full_index)
            .dropna()
            .reset_index()
            .reindex(columns=df.columns)
        )
        assert_frame_equal(result, expected)

    @pytest.mark.parametrize("func", ["max", lambda x: float(len(x) % 2)])
    def test_func(self, df, func):
        """Agg should accept both a named reduction and an arbitrary callable."""
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Agg(func)(data, groupby, orient, {})

        expected = data.groupby("x", as_index=False)["y"].agg(func)
        assert_frame_equal(result, expected)


class TestEst(AggregationFixtures):
    """Tests for the Est stat."""

    # Note: Most of the underlying code is exercised in tests/test_statistics

    @pytest.mark.parametrize("func", [np.mean, "mean"])
    def test_mean_sd(self, df, func):
        """Mean estimate with "sd" error bars should be mean -/+ std per group."""
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Est(func, "sd")(data, groupby, orient, {})

        y_by_x = data.groupby("x", as_index=False)["y"]
        center = y_by_x.mean()
        spread = y_by_x.std().fillna(0)  # fillna needed only on pinned tests
        lower = center["y"] - spread["y"]
        upper = center["y"] + spread["y"]
        expected = center.assign(ymin=lower, ymax=upper)
        assert_frame_equal(result, expected)

    def test_sd_single_obs(self):
        """A single observation should yield an interval collapsed onto the value."""
        value = 1.5
        orient = "x"
        data = pd.DataFrame([{"x": "a", "y": value}])
        groupby = self.get_groupby(data, orient)
        result = Est("mean", "sd")(data, groupby, orient, {})
        assert_frame_equal(result, data.assign(ymin=value, ymax=value))

    def test_median_pi(self, df):
        """Median with a 100% percentile interval should span each group's min/max."""
        orient = "x"
        data = df[["x", "y"]]
        groupby = self.get_groupby(data, orient)
        result = Est("median", ("pi", 100))(data, groupby, orient, {})

        y_by_x = data.groupby("x", as_index=False)["y"]
        lower = y_by_x.min()["y"]
        upper = y_by_x.max()["y"]
        expected = y_by_x.median().assign(ymin=lower, ymax=upper)
        assert_frame_equal(result, expected)

    def test_weighted_mean(self, df, rng):
        """With a "weight" column present, the estimate should equal np.average."""
        weights = rng.uniform(0, 5, len(df))
        groupby = self.get_groupby(df[["x", "y"]], "x")
        weighted = df.assign(weight=weights)
        result = Est("mean")(weighted, groupby, "x", {})
        for _, row in result.iterrows():
            subset = weighted[weighted["x"] == row["x"]]
            expected = np.average(subset["y"], weights=subset["weight"])
            assert row["y"] == expected

    def test_seed(self, df):
        """Two Est objects built with the same seed should give identical output."""
        orient = "x"
        groupby = self.get_groupby(df, orient)
        call_args = (df, groupby, orient, {})
        first = Est("mean", "ci", seed=99)(*call_args)
        second = Est("mean", "ci", seed=99)(*call_args)
        assert_frame_equal(first, second)


================================================
FILE: tests/_stats/test_counting.py
================================================

import numpy as np
import pandas as pd

import pytest
from numpy.testing import assert_array_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.counting import Hist, Count


class TestCount:
    """Tests for the Count stat."""

    @pytest.fixture
    def df(self, rng):
        """Return a 30-row frame with numeric x/y and string color/group columns."""
        n = 30
        columns = {
            "x": rng.uniform(0, 7, n).round(),
            "y": rng.normal(size=n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        }
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Build a GroupBy over every column of `df` except the non-orient axis."""
        flipped = {"x": "y", "y": "x"}
        value_var = flipped[orient]
        grouping = [name for name in df if name != value_var]
        return GroupBy(grouping)

    def test_single_grouper(self, df):
        """Counts per x value should match the pandas group sizes."""
        orient = "x"
        data = df[["x"]]
        groupby = self.get_groupby(data, orient)
        result = Count()(data, groupby, orient, {})
        sizes = data.groupby("x").size()
        counted = result.sort_values("x")["y"]
        assert_array_equal(counted, sizes)

    def test_multiple_groupers(self, df):
        """Counts per (x, group) pair should match the pandas group sizes."""
        orient = "x"
        data = df[["x", "group"]].sort_values("group")
        groupby = self.get_groupby(data, orient)
        result = Count()(data, groupby, orient, {})
        sizes = data.groupby(["x", "group"]).size()
        counted = result.sort_values(["x", "group"])["y"]
        assert_array_equal(counted, sizes)


class TestHist:
    """Tests for the Hist stat: bin parameters, normalization and grouping.

    Normalized statistics are sums of floating-point values, so they are
    compared with ``pytest.approx`` rather than exact equality.
    """

    @pytest.fixture
    def single_args(self):
        """Return ``(groupby, orient, scales)`` with one grouping variable."""

        groupby = GroupBy(["group"])

        # Minimal stand-in exposing only the attribute set here.
        class Scale:
            scale_type = "continuous"

        return groupby, "x", {"x": Scale()}

    @pytest.fixture
    def triple_args(self):
        """Return ``(groupby, orient, scales)`` grouping by "group", "a" and "s"."""

        groupby = GroupBy(["group", "a", "s"])

        # Minimal stand-in exposing only the attribute set here.
        class Scale:
            scale_type = "continuous"

        return groupby, "x", {"x": Scale()}

    def test_string_bins(self, long_df):
        """A named bin rule should resolve to a data range and a bin count."""

        h = Hist(bins="sqrt")
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max())
        assert bin_kws["bins"] == int(np.sqrt(len(long_df)))

    def test_int_bins(self, long_df):
        """An integer ``bins`` should be passed through with the data range."""

        n = 24
        h = Hist(bins=n)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert bin_kws["range"] == (long_df["x"].min(), long_df["x"].max())
        assert bin_kws["bins"] == n

    def test_array_bins(self, long_df):
        """Explicit bin edges should be passed through unchanged."""

        bins = [-3, -2, 1, 2, 3]
        h = Hist(bins=bins)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert_array_equal(bin_kws["bins"], bins)

    def test_binwidth(self, long_df):
        """The resolved range divided by the bin count should equal ``binwidth``."""

        binwidth = .5
        h = Hist(binwidth=binwidth)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        n_bins = bin_kws["bins"]
        left, right = bin_kws["range"]
        assert (right - left) / n_bins == pytest.approx(binwidth)

    def test_binrange(self, long_df):
        """An explicit ``binrange`` should become the histogram range."""

        binrange = (-4, 4)
        h = Hist(binrange=binrange)
        bin_kws = h._define_bin_params(long_df, "x", "continuous")
        assert bin_kws["range"] == binrange

    def test_discrete_bins(self, long_df):
        """``discrete=True`` should give unit-width bins centered on integers."""

        h = Hist(discrete=True)
        x = long_df["x"].astype(int)
        bin_kws = h._define_bin_params(long_df.assign(x=x), "x", "continuous")
        assert bin_kws["range"] == (x.min() - .5, x.max() + .5)
        assert bin_kws["bins"] == (x.max() - x.min() + 1)

    def test_discrete_bins_from_nominal_scale(self, rng):
        """A "nominal" scale type should give the same bins as ``discrete=True``."""

        h = Hist()
        x = rng.randint(0, 5, 10)
        df = pd.DataFrame({"x": x})
        bin_kws = h._define_bin_params(df, "x", "nominal")
        assert bin_kws["range"] == (x.min() - .5, x.max() + .5)
        assert bin_kws["bins"] == (x.max() - x.min() + 1)

    def test_count_stat(self, long_df, single_args):
        """Counts should sum to the number of observations."""

        h = Hist(stat="count")
        out = h(long_df, *single_args)
        assert out["y"].sum() == len(long_df)

    def test_probability_stat(self, long_df, single_args):
        """Probabilities should sum to one."""

        h = Hist(stat="probability")
        out = h(long_df, *single_args)
        assert out["y"].sum() == pytest.approx(1)

    def test_proportion_stat(self, long_df, single_args):
        """Proportions should sum to one."""

        h = Hist(stat="proportion")
        out = h(long_df, *single_args)
        assert out["y"].sum() == pytest.approx(1)

    def test_percent_stat(self, long_df, single_args):
        """Percentages should sum to one hundred."""

        h = Hist(stat="percent")
        out = h(long_df, *single_args)
        assert out["y"].sum() == pytest.approx(100)

    def test_density_stat(self, long_df, single_args):
        """Density times the "space" column should sum to one."""

        h = Hist(stat="density")
        out = h(long_df, *single_args)
        assert (out["y"] * out["space"]).sum() == pytest.approx(1)

    def test_frequency_stat(self, long_df, single_args):
        """Frequency times the "space" column should sum to the observation count."""

        h = Hist(stat="frequency")
        out = h(long_df, *single_args)
        assert (out["y"] * out["space"]).sum() == pytest.approx(len(long_df))

    def test_invalid_stat(self):
        """An unknown ``stat`` should be rejected at construction time."""

        with pytest.raises(ValueError, match="The `stat` parameter for `Hist`"):
            Hist(stat="invalid")

    def test_cumulative_count(self, long_df, single_args):
        """Cumulative counts should peak at the number of observations."""

        h = Hist(stat="count", cumulative=True)
        out = h(long_df, *single_args)
        assert out["y"].max() == len(long_df)

    def test_cumulative_proportion(self, long_df, single_args):
        """Cumulative proportions should peak at one."""

        h = Hist(stat="proportion", cumulative=True)
        out = h(long_df, *single_args)
        assert out["y"].max() == pytest.approx(1)

    def test_cumulative_density(self, long_df, single_args):
        """Cumulative density should peak at one."""

        h = Hist(stat="density", cumulative=True)
        out = h(long_df, *single_args)
        assert out["y"].max() == pytest.approx(1)

    def test_common_norm_default(self, long_df, triple_args):
        """By default, normalization is applied across all groups together."""

        h = Hist(stat="percent")
        out = h(long_df, *triple_args)
        assert out["y"].sum() == pytest.approx(100)

    def test_common_norm_false(self, long_df, triple_args):
        """With ``common_norm=False`` each (a, s) group is normalized on its own."""

        h = Hist(stat="percent", common_norm=False)
        out = h(long_df, *triple_args)
        for _, out_part in out.groupby(["a", "s"]):
            assert out_part["y"].sum() == pytest.approx(100)

    def test_common_norm_subset(self, long_df, triple_args):
        """With ``common_norm=["a"]`` normalization is applied within each "a" level."""

        h = Hist(stat="percent", common_norm=["a"])
        out = h(long_df, *triple_args)
        for _, out_part in out.groupby("a"):
            assert out_part["y"].sum() == pytest.approx(100)

    def test_common_norm_warning(self, long_df, triple_args):
        """Naming a non-grouping variable in ``common_norm`` should warn."""

        h = Hist(common_norm=["b"])
        with pytest.warns(UserWarning, match=r"Undefined variable\(s\)"):
            h(long_df, *triple_args)

    def test_common_bins_default(self, long_df, triple_args):
        """By default every (a, s) group is evaluated on the same bins."""

        h = Hist()
        out = h(long_df, *triple_args)
        bins = []
        for _, out_part in out.groupby(["a", "s"]):
            bins.append(tuple(out_part["x"]))
        assert len(set(bins)) == 1

    def test_common_bins_false(self, long_df, triple_args):
        """With ``common_bins=False`` every (a, s) group gets distinct bins."""

        h = Hist(common_bins=False)
        out = h(long_df, *triple_args)
        bins = []
        for _, out_part in out.groupby(["a", "s"]):
            bins.append(tuple(out_part["x"]))
        assert len(set(bins)) == len(out.groupby(["a", "s"]))

    def test_common_bins_subset(self, long_df, triple_args):
        """Bin positions should differ between levels of "a"."""

        # NOTE(review): the test name suggests ``common_bins=["a"]`` was
        # intended, but ``False`` is what is exercised here; confirm.
        h = Hist(common_bins=False)
        out = h(long_df, *triple_args)
        bins = []
        for _, out_part in out.groupby("a"):
            bins.append(tuple(out_part["x"]))
        assert len(set(bins)) == out["a"].nunique()

    def test_common_bins_warning(self, long_df, triple_args):
        """Naming a non-grouping variable in ``common_bins`` should warn."""

        h = Hist(common_bins=["b"])
        with pytest.warns(UserWarning, match=r"Undefined variable\(s\)"):
            h(long_df, *triple_args)

    def test_histogram_single(self, long_df, single_args):
        """Ungrouped output should match ``np.histogram`` with "auto" bins."""

        h = Hist()
        out = h(long_df, *single_args)
        hist, edges = np.histogram(long_df["x"], bins="auto")
        assert_array_equal(out["y"], hist)
        assert_array_equal(out["space"], np.diff(edges))

    def test_histogram_multiple(self, long_df, triple_args):
        """Each (a, s) group should match ``np.histogram`` on the shared bins."""

        h = Hist()
        out = h(long_df, *triple_args)
        bins = np.histogram_bin_edges(long_df["x"], "auto")
        for (a, s), out_part in out.groupby(["a", "s"]):
            x = long_df.loc[(long_df["a"] == a) & (long_df["s"] == s), "x"]
            hist, edges = np.histogram(x, bins=bins)
            assert_array_equal(out_part["y"], hist)
            assert_array_equal(out_part["space"], np.diff(edges))


================================================
FILE: tests/_stats/test_density.py
================================================
import numpy as np
import pandas as pd

import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.density import KDE, _no_scipy
from seaborn._compat import groupby_apply_include_groups


class TestKDE:
    """Tests for the KDE stat."""

    @pytest.fixture
    def df(self, rng):
        """Return a 100-row frame with numeric x/y and string color/alpha columns."""
        n = 100
        columns = {
            "x": rng.uniform(0, 7, n).round(),
            "y": rng.normal(size=n),
            "color": rng.choice(["a", "b", "c"], n),
            "alpha": rng.choice(["x", "y"], n),
        }
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Build a GroupBy over the non-orient columns of `df` plus "group"."""
        grouping = [name for name in df if name != orient]
        grouping.append("group")
        return GroupBy(grouping)

    def integrate(self, y, x):
        """Integrate `y` over `x` with the trapezoid rule."""
        heights = np.asarray(y)
        widths = np.diff(np.asarray(x))
        return (widths * heights[:-1] + widths * heights[1:]).sum() / 2

    @pytest.mark.parametrize("ori", ["x", "y"])
    def test_columns(self, df, ori):
        """Output columns: orient, grouping vars, "density", then the other axis."""
        data = df[[ori, "alpha"]]
        groupby = self.get_groupby(data, ori)
        out = KDE()(data, groupby, ori, {})
        value_var = {"x": "y", "y": "x"}[ori]
        assert list(out.columns) == [ori, "alpha", "density", value_var]

    @pytest.mark.parametrize("gridsize", [20, 30, None])
    def test_gridsize(self, df, gridsize):
        """``gridsize`` sets the output length; ``None`` evaluates at the data."""
        orient = "y"
        data = df[[orient]]
        groupby = self.get_groupby(data, orient)
        out = KDE(gridsize=gridsize)(data, groupby, orient, {})
        if gridsize is not None:
            assert len(out) == gridsize
        else:
            assert_array_equal(out[orient], data[orient])

    @pytest.mark.parametrize("cut", [1, 2])
    def test_cut(self, df, cut):
        """The grid should extend ``cut`` bandwidths beyond the data extremes."""
        orient = "y"
        data = df[[orient]]
        groupby = self.get_groupby(data, orient)
        out = KDE(cut=cut, bw_method=1)(data, groupby, orient, {})

        observed = data[orient]
        pad = observed.std() * cut
        assert out[orient].min() == pytest.approx(observed.min() - pad, abs=1e-2)
        assert out[orient].max() == pytest.approx(observed.max() + pad, abs=1e-2)

    @pytest.mark.parametrize("common_grid", [True, False])
    def test_common_grid(self, df, common_grid):
        """Groups share evaluation points only when ``common_grid`` is true."""
        orient = "y"
        data = df[[orient, "alpha"]]
        groupby = self.get_groupby(data, orient)
        out = KDE(common_grid=common_grid)(data, groupby, orient, {})

        levels = data["alpha"].unique()
        first_grid = out.loc[out["alpha"] == levels[0], orient].to_numpy()
        second_grid = out.loc[out["alpha"] == levels[1], orient].to_numpy()
        if not common_grid:
            assert np.not_equal(first_grid, second_grid).all()
        else:
            assert_array_equal(first_grid, second_grid)

    @pytest.mark.parametrize("common_norm", [True, False])
    def test_common_norm(self, df, common_norm):
        """Areas sum to one with common norm, else each group integrates to one."""
        orient = "y"
        data = df[[orient, "alpha"]]
        groupby = self.get_groupby(data, orient)
        out = KDE(common_norm=common_norm)(data, groupby, orient, {})

        def area(part):
            return self.integrate(part["density"], part[orient])

        areas = out.groupby("alpha").apply(
            area, **groupby_apply_include_groups(False)
        )

        if not common_norm:
            assert_array_almost_equal(areas, [1, 1], decimal=3)
        else:
            assert areas.sum() == pytest.approx(1, abs=1e-3)

    def test_common_norm_variables(self, df):
        """``common_norm=["alpha"]`` should normalize within each alpha level."""
        orient = "y"
        data = df[[orient, "alpha", "color"]]
        groupby = self.get_groupby(data, orient)
        out = KDE(common_norm=["alpha"])(data, groupby, orient, {})

        def area(part):
            return self.integrate(part["density"], part[orient])

        def total_area_over_colors(part):
            per_color = part.groupby("color").apply(
                area, **groupby_apply_include_groups(False)
            )
            return per_color.sum()

        areas = out.groupby("alpha").apply(
            total_area_over_colors, **groupby_apply_include_groups(False)
        )
        assert_array_almost_equal(areas, [1, 1], decimal=3)

    @pytest.mark.parametrize("param", ["norm", "grid"])
    def test_common_input_checks(self, df, param):
        """Undefined variables warn; a bare string value raises TypeError."""
        orient = "y"
        data = df[[orient, "alpha"]]
        groupby = self.get_groupby(data, orient)
        keyword = f"common_{param}"

        warning = rf"Undefined variable\(s\) passed for KDE.common_{param}"
        with pytest.warns(UserWarning, match=warning):
            KDE(**{keyword: ["color", "alpha"]})(data, groupby, orient, {})

        error = f"KDE.common_{param} must be a boolean or list of strings"
        with pytest.raises(TypeError, match=error):
            KDE(**{keyword: "alpha"})(data, groupby, orient, {})

    def test_bw_adjust(self, df):
        """A smaller ``bw_adjust`` should produce a less smooth density curve."""
        orient = "y"
        data = df[[orient]]
        groupby = self.get_groupby(data, orient)
        narrow = KDE(bw_adjust=0.5)(data, groupby, orient, {})
        wide = KDE(bw_adjust=2.0)(data, groupby, orient, {})

        # Mean absolute step between neighboring density values.
        roughness_narrow = narrow["density"].diff().abs().mean()
        roughness_wide = wide["density"].diff().abs().mean()
        assert roughness_narrow > roughness_wide

    def test_bw_method_scalar(self, df):
        """A smaller scalar ``bw_method`` should produce a less smooth curve."""
        orient = "y"
        data = df[[orient]]
        groupby = self.get_groupby(data, orient)
        narrow = KDE(bw_method=0.5)(data, groupby, orient, {})
        wide = KDE(bw_method=2.0)(data, groupby, orient, {})

        # Mean absolute step between neighboring density values.
        roughness_narrow = narrow["density"].diff().abs().mean()
        roughness_wide = wide["density"].diff().abs().mean()
        assert roughness_narrow > roughness_wide

    @pytest.mark.skipif(_no_scipy, reason="KDE.cumulative requires scipy")
    @pytest.mark.parametrize("common_norm", [True, False])
    def test_cumulative(self, df, common_norm):
        """Cumulative output is non-decreasing; reaches one without common norm."""
        orient = "y"
        data = df[[orient, "alpha"]]
        groupby = self.get_groupby(data, orient)
        out = KDE(cumulative=True, common_norm=common_norm)(
            data, groupby, orient, {}
        )

        for _, part in out.groupby("alpha"):
            steps = part["density"].diff().dropna()
            assert (steps >= 0).all()
            if not common_norm:
                assert part["density"].max() == pytest.approx(1, abs=1e-3)

    def test_cumulative_requires_scipy(self):
        """Without scipy, requesting a cumulative KDE should raise RuntimeError."""
        if not _no_scipy:
            return
        message = "Cumulative KDE evaluation requires scipy"
        with pytest.raises(RuntimeError, match=message):
            KDE(cumulative=True)

    @pytest.mark.parametrize("vals", [[], [1], [1] * 5, [1929245168.06679] * 18])
    def test_singular(self, df, vals):
        """Empty or constant input gives no output and is dropped from mixed data."""
        singular = pd.DataFrame({"y": vals, "alpha": ["z"] * len(vals)})
        groupby = self.get_groupby(singular, "y")
        out = KDE()(singular, groupby, "y", {})
        assert out.empty

        combined = pd.concat([df[["y", "alpha"]], singular], ignore_index=True)
        groupby = self.get_groupby(combined, "y")
        out = KDE()(combined, groupby, "y", {})
        assert set(out["alpha"]) == set(df["alpha"])

    @pytest.mark.parametrize("col", ["y", "weight"])
    def test_missing(self, df, col):
        """Missing values in the data or the weights should not break the estimate."""
        val, ori = "xy"
        df["weight"] = 1
        data = df[[ori, "weight"]].astype(float)
        data.loc[:4, col] = np.nan
        groupby = self.get_groupby(data, ori)
        out = KDE()(data, groupby, ori, {})
        area = self.integrate(out[val], out[ori])
        assert area == pytest.approx(1, abs=1e-3)


================================================
FILE: tests/_stats/test_order.py
================================================

import numpy as np
import pandas as pd

import pytest
from numpy.testing import assert_array_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.order import Perc
from seaborn.utils import _version_predates


class Fixtures:
    """Shared data fixture and GroupBy helper for the order stat tests."""

    @pytest.fixture
    def df(self, rng):
        """Return 30 normal draws in "y" with a constant empty-string "x"."""
        columns = dict(x="", y=rng.normal(size=30))
        return pd.DataFrame(columns)

    def get_groupby(self, df, orient):
        """Build a GroupBy over every column of `df` except the non-orient axis."""
        # TODO note, copied from aggregation
        flipped = {"x": "y", "y": "x"}
        value_var = flipped[orient]
        grouping = [name for name in df if name != value_var]
        return GroupBy(grouping)


class TestPerc(Fixtures):
    """Tests for the Perc stat."""

    def test_int_k(self, df):
        """An integer ``k`` should give that many evenly spaced percentiles."""
        orient = "x"
        groupby = self.get_groupby(df, orient)
        out = Perc(3)(df, groupby, orient, {})
        expected_k = [0, 50, 100]
        assert_array_equal(out["percentile"], expected_k)
        assert_array_equal(out["y"], np.percentile(df["y"], expected_k))

    def test_list_k(self, df):
        """A list ``k`` should be used as the percentiles directly."""
        orient = "x"
        groupby = self.get_groupby(df, orient)
        requested = [0, 20, 100]
        out = Perc(k=requested)(df, groupby, orient, {})
        assert_array_equal(out["percentile"], requested)
        assert_array_equal(out["y"], np.percentile(df["y"], requested))

    def test_orientation(self, df):
        """With ``orient="y"`` the percentiles should be computed on x."""
        swapped = df.rename(columns={"x": "y", "y": "x"})
        orient = "y"
        groupby = self.get_groupby(swapped, orient)
        out = Perc(k=3)(swapped, groupby, orient, {})
        expected = np.percentile(swapped["x"], [0, 50, 100])
        assert_array_equal(out["x"], expected)

    def test_method(self, df):
        """The ``method`` argument should be forwarded to the percentile computation."""
        orient = "x"
        groupby = self.get_groupby(df, orient)
        method = "nearest"
        out = Perc(k=5, method=method)(df, groupby, orient, {})
        ks = [0, 25, 50, 75, 100]
        # The numpy keyword is spelled `interpolation` before 1.22.0.
        if not _version_predates(np, "1.22.0"):
            expected = np.percentile(df["y"], ks, method=method)
        else:
            expected = np.percentile(df["y"], ks, interpolation=method)
        assert_array_equal(out["y"], expected)

    def test_grouped(self, df, rng):
        """Percentiles should be computed separately for each level of x."""
        orient = "x"
        grouped = df.assign(x=rng.choice(["a", "b", "c"], len(df)))
        groupby = self.get_groupby(grouped, orient)
        ks = [10, 90]
        out = Perc(ks)(grouped, groupby, orient, {})
        for level, part in out.groupby("x"):
            assert_array_equal(part["percentile"], ks)
            level_values = grouped.loc[grouped["x"] == level, "y"]
            assert_array_equal(part["y"], np.percentile(level_values, ks))

    def test_with_na(self, df):
        """Missing values should be dropped before computing percentiles."""
        orient = "x"
        df.loc[:5, "y"] = np.nan
        groupby = self.get_groupby(df, orient)
        ks = [10, 90]
        out = Perc(ks)(df, groupby, orient, {})
        valid = df["y"].dropna()
        assert_array_equal(out["y"], np.percentile(valid, ks))


================================================
FILE: tests/_stats/test_regression.py
================================================

import numpy as np
import pandas as pd

import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from pandas.testing import assert_frame_equal

from seaborn._core.groupby import GroupBy
from seaborn._stats.regression import PolyFit


class TestPolyFit:
    """Tests for the PolyFit stat."""

    @pytest.fixture
    def df(self, rng):
        """Return a 100-row frame with normal x/y and string color/group columns."""
        n = 100
        columns = {
            "x": rng.normal(0, 1, n),
            "y": rng.normal(0, 1, n),
            "color": rng.choice(["a", "b", "c"], n),
            "group": rng.choice(["x", "y"], n),
        }
        return pd.DataFrame(columns)

    def test_no_grouper(self, df):
        """An order-1 fit on ungrouped data gives a straight line on a full grid."""
        groupby = GroupBy(["group"])
        stat = PolyFit(order=1, gridsize=100)
        out = stat(df[["x", "y"]], groupby, "x", {})

        assert_array_equal(out.columns, ["x", "y"])

        expected_grid = np.linspace(df["x"].min(), df["x"].max(), 100)
        assert_array_equal(out["x"], expected_grid)
        # Second differences of a straight line are zero.
        curvature = out["y"].diff().diff().dropna()
        assert_array_almost_equal(curvature, np.zeros(expected_grid.size - 2))

    def test_one_grouper(self, df):
        """Each group should get its own grid and a curved (default order) fit."""
        groupby = GroupBy(["group"])
        gridsize = 50
        out = PolyFit(gridsize=gridsize)(df, groupby, "x", {})

        assert out.columns.to_list() == ["x", "y", "group"]

        n_groups = df["group"].nunique()
        assert_array_equal(out.index, np.arange(n_groups * gridsize))

        for _, part in out.groupby("group"):
            expected_grid = np.linspace(part["x"].min(), part["x"].max(), gridsize)
            assert_array_equal(part["x"], expected_grid)
            curvature = part["y"].diff().diff().dropna()
            assert curvature.abs().gt(0).all()

    def test_missing_data(self, df):
        """Rows with missing values should be ignored by the fit."""
        groupby = GroupBy(["group"])
        df.iloc[5:10] = np.nan
        with_na = PolyFit()(df[["x", "y"]], groupby, "x", {})
        without_na = PolyFit()(df[["x", "y"]].dropna(), groupby, "x", {})
        assert_frame_equal(with_na, without_na)


================================================
FILE: tests/conftest.py
================================================
import numpy as np
import pandas as pd

import pytest


@pytest.fixture(autouse=True)
def close_figs():
    """Close every open matplotlib figure once each test has finished."""
    yield
    # Imported only at teardown, after the test body has run.
    from matplotlib import pyplot
    pyplot.close("all")


@pytest.fixture(autouse=True)
def random_seed():
    """Seed numpy's global random state before every test."""
    label = "seaborn random global"
    # Seed is the sum of the label's character code points.
    np.random.seed(sum(ord(char) for char in label))


@pytest.fixture()
def rng():
    """Return a fresh, deterministically seeded ``np.random.RandomState``."""
    label = "seaborn random object"
    # Seed is the sum of the label's character code points.
    return np.random.RandomState(sum(ord(char) for char in label))


@pytest.fixture
def wide_df(rng):
    """Return a wide-form frame of normal draws with columns a/b/c."""
    names = list("abc")
    row_index = pd.RangeIndex(10, 50, 2, name="wide_index")
    shape = (len(row_index), len(names))
    data = rng.normal(size=shape)
    return pd.DataFrame(data, index=row_index, columns=names)


@pytest.fixture
def wide_array(wide_df):
    """Return the values of ``wide_df`` as a numpy array."""
    values = wide_df.to_numpy()
    return values


# TODO s/flat/thin?
@pytest.fixture
def flat_series(rng):
    """Return a 20-element series of normal draws named "s" with index "t"."""
    positions = pd.RangeIndex(10, 30, name="t")
    draws = rng.normal(size=20)
    return pd.Series(draws, index=positions, name="s")


@pytest.fixture
def flat_array(flat_series):
    """Return the values of ``flat_series`` as a numpy array."""
    values = flat_series.to_numpy()
    return values


@pytest.fixture
def flat_list(flat_series):
    """Return the values of ``flat_series`` as a plain list."""
    values = flat_series.to_list()
    return values


@pytest.fixture(params=["series", "array", "list"])
def flat_data(rng, request):
    """Return the same 20 normal draws as a Series, an array, or a list."""
    positions = pd.RangeIndex(10, 30, name="t")
    series = pd.Series(rng.normal(size=20), index=positions, name="s")

    container = request.param
    if container == "array":
        return series.to_numpy()
    if container == "list":
        return series.to_list()
    if container == "series":
        return series


@pytest.fixture
def wide_list_of_series(rng):
    """Return two series of different lengths with partly overlapping indexes."""
    series_a = pd.Series(rng.normal(size=20), np.arange(20), name="a")
    series_b = pd.Series(rng.normal(size=10), np.arange(5, 15), name="b")
    return [series_a, series_b]


@pytest.fixture
def wide_list_of_arrays(wide_list_of_series):
    """Return ``wide_list_of_series`` with each series converted to an array."""
    arrays = []
    for series in wide_list_of_series:
        arrays.append(series.to_numpy())
    return arrays


@pytest.fixture
def wide_list_of_lists(wide_list_of_series):
    """Return ``wide_list_of_series`` with each series converted to a list."""
    lists = []
    for series in wide_list_of_series:
        lists.append(series.to_list())
    return lists


@pytest.fixture
def wide_dict_of_series(wide_list_of_series):
    """Return the series from ``wide_list_of_series`` keyed by their names."""
    by_name = {}
    for series in wide_list_of_series:
        by_name[series.name] = series
    return by_name


@pytest.fixture
def wide_dict_of_arrays(wide_list_of_series):
    """Return arrays of the ``wide_list_of_series`` values keyed by series name."""
    by_name = {}
    for series in wide_list_of_series:
        by_name[series.name] = series.to_numpy()
    return by_name


@pytest.fixture
def wide_dict_of_lists(wide_list_of_series):
    """Return lists of the ``wide_list_of_series`` values keyed by series name."""
    by_name = {}
    for series in wide_list_of_series:
        by_name[series.name] = series.to_list()
    return by_name


@pytest.fixture
def long_df(rng):
    """Return a 100-row long-form frame with columns of assorted dtypes.

    Columns are drawn from `rng` in the order written below, so the values
    depend on that order. Derived columns ``a_cat``, ``s_cat`` and ``s_str``
    are added after construction.
    """

    n = 100
    df = pd.DataFrame(dict(
        # Integer-valued: uniform draws on [0, 20) rounded and cast to int
        x=rng.uniform(0, 20, n).round().astype("int"),
        y=rng.normal(size=n),
        z=rng.lognormal(size=n),
        a=rng.choice(list("abc"), n),
        b=rng.choice(list("mnop"), n),
        # NOTE(review): the third positional argument of RandomState.choice is
        # `replace`, not `p`, so [.3, .7] is not used as probabilities here.
        # Switching to `p=` would presumably change the values drawn for this
        # and later columns; confirm before changing.
        c=rng.choice([0, 1], n, [.3, .7]),
        # Year-resolution datetimes
        d=rng.choice(np.arange("2004-07-30", "2007-07-30", dtype="datetime64[Y]"), n),
        # Minute-resolution datetimes within a single day
        t=rng.choice(np.arange("2004-07-30", "2004-07-31", dtype="datetime64[m]"), n),
        s=rng.choice([2, 4, 8], n),
        f=rng.choice([0.2, 0.3], n),
    ))

    # Categorical copy of "a" whose category order is rotated by one position
    a_cat = df["a"].astype("category")
    new_categories = np.roll(a_cat.cat.categories, 1)
    df["a_cat"] = a_cat.cat.reorder_categories(new_categories)

    # Categorical and string versions of the numeric "s" column
    df["s_cat"] = df["s"].astype("category")
    df["s_str"] = df["s"].astype(str)

    return df


@pytest.fixture
def long_dict(long_df):
    """Return ``long_df`` converted with ``DataFrame.to_dict()``."""
    as_dict = long_df.to_dict()
    return as_dict


@pytest.fixture
def repeated_df(rng):
    """Return a 100-row frame where each x value appears once per unit ``u``."""
    n = 100
    half = n // 2
    columns = {
        "x": np.tile(np.arange(half), 2),
        "y": rng.normal(size=n),
        "a": rng.choice(list("abc"), n),
        "u": np.repeat(np.arange(2), half),
    }
    return pd.DataFrame(columns)


@pytest.fixture
def null_df(rng, long_df):
    """Return a copy of ``long_df`` with 10 random rows nulled in every column."""
    nulled = long_df.copy()
    for name in nulled:
        # Integer columns are cast to float before NaN is assigned.
        if pd.api.types.is_integer_dtype(nulled[name]):
            nulled[name] = nulled[name].astype(float)
        rows = rng.permutation(nulled.index)[:10]
        nulled.loc[rows, name] = np.nan
    return nulled


@pytest.fixture
def object_df(rng, long_df):
    """Return a copy of ``long_df`` with columns c, s and f cast to object dtype."""
    converted = long_df.copy()
    # objectify numeric columns
    numeric_names = ["c", "s", "f"]
    for name in numeric_names:
        converted[name] = converted[name].astype(object)
    return converted


@pytest.fixture
def null_series(flat_series):
    """Return a float64 series with the index of ``flat_series`` and no data."""
    shared_index = flat_series.index
    return pd.Series(index=shared_index, dtype='float64')


class MockConvertibleDataFrame:
    """Mock object that is not a pandas.DataFrame but exposes ``to_pandas``.

    It wraps an arbitrary object and hands it back from ``to_pandas``;
    wrapping ``None`` makes the conversion fail.
    """

    def __init__(self, data):
        self._data = data

    def to_pandas(self, *args, **kwargs):
        """Return the wrapped object, ignoring any arguments.

        Raises ValueError when the wrapped object is ``None``.
        """
        wrapped = self._data
        if wrapped is not None:
            return wrapped
        raise ValueError("Cannot convert to pandas")


@pytest.fixture
def mock_long_df(long_df):
    """Return ``long_df`` wrapped in a ``MockConvertibleDataFrame``."""
    wrapped = MockConvertibleDataFrame(long_df)
    return wrapped


================================================
FILE: tests/test_algorithms.py
================================================
import numpy as np

import pytest
from numpy.testing import assert_array_equal

from seaborn import algorithms as algo


@pytest.fixture
def random():
    """Seed numpy's global random state for the algorithm tests."""
    label = "test_algorithms"
    np.random.seed(sum(ord(char) for char in label))


def test_bootstrap(random):
    """Bootstrapping constant data must return that constant for every resample."""
    n_boot = 5
    ones = np.ones(10)
    expected = np.ones(n_boot)

    default_out = algo.bootstrap(ones, n_boot=n_boot)
    assert_array_equal(default_out, expected)

    median_out = algo.bootstrap(ones, n_boot=n_boot, func=np.median)
    assert_array_equal(median_out, expected)


def test_bootstrap_length(random):
    """The bootstrap output length should be 10000 by default, else ``n_boot``."""
    sample = np.random.randn(1000)

    default_out = algo.bootstrap(sample)
    assert len(default_out) == 10000

    n_boot = 100
    sized_out = algo.bootstrap(sample, n_boot=n_boot)
    assert len(sized_out) == n_boot


def test_bootstrap_range(random):
    """Bootstrapped statistics must lie within the range of the original data."""
    sample = np.random.randn(1000)
    lowest = sample.min()
    highest = sample.max()
    boots = algo.bootstrap(sample)
    assert boots.min() >= lowest
    assert boots.max() <= highest


def test_bootstrap_multiarg(random):
    """Bootstrap should pass several input arrays through to ``func``."""
    x = np.vstack([[1, 10]] * 10)
    y = np.vstack([[5, 5]] * 10)

    def column_max(first, second):
        stacked = np.vstack((first, second))
        return stacked.max(axis=0)

    actual = algo.bootstrap(x, y, n_boot=2, func=column_max)
    wanted = np.array([[5, 10], [5, 10]])
    assert_array_equal(actual, wanted)


def test_bootstrap_axis(random):
    """Test axis kwarg to bootstrap function.

    Without ``axis`` each resample produces one value; with ``axis=0`` each
    resample produces one value per column of the input.
    """
    x = np.random.randn(10, 20)
    n_boot = 100

    out_default = algo.bootstrap(x, n_boot=n_boot)
    assert out_default.shape == (n_boot,)

    out_axis = algo.bootstrap(x, n_boot=n_boot, axis=0)
    # Compare with `==`: `assert value, tuple` treats the tuple as the
    # failure message and passes for any non-empty shape.
    assert out_axis.shape == (n_boot, x.shape[1])


def test_bootstrap_seed(random):
    """Two bootstraps run with the same integer seed should be identical."""
    sample = np.random.randn(50)
    seed = 42
    first = algo.bootstrap(sample, seed=seed)
    second = algo.bootstrap(sample, seed=seed)
    assert_array_equal(first, second)


def test_bootstrap_ols(random):
    """Bootstrapped OLS coefficients should vary more when the data are noisier."""
    def ols_fit(design, target):
        # Normal equations: (X'X)^-1 X'y
        gram_inverse = np.linalg.inv(np.dot(design.T, design))
        return gram_inverse.dot(design.T).dot(target)

    design = np.column_stack((np.random.randn(50, 4), np.ones(50)))
    weights = [2, 4, 0, 3, 5]
    y_noisy = np.dot(design, weights) + np.random.randn(50) * 20
    y_lownoise = np.dot(design, weights) + np.random.randn(50)

    n_boot = 500
    boots_noisy = algo.bootstrap(design, y_noisy, n_boot=n_boot, func=ols_fit)
    boots_lownoise = algo.bootstrap(
        design, y_lownoise, n_boot=n_boot, func=ols_fit
    )

    expected_shape = (n_boot, 5)
    assert boots_noisy.shape == expected_shape
    assert boots_lownoise.shape == expected_shape
    assert boots_noisy.std() > boots_lownoise.std()


def test_bootstrap_units(random):
    """Resampling by unit should widen the bootstrap distribution."""
    noise = np.random.randn(50)
    unit_ids = np.repeat(range(10), 5)
    # One offset per unit, broadcast to that unit's observations.
    unit_offsets = np.random.normal(0, 2, 10)
    repeated_measures = noise + unit_offsets[unit_ids]
    seed = 77

    boots_plain = algo.bootstrap(repeated_measures, seed=seed)
    boots_units = algo.bootstrap(repeated_measures, units=unit_ids, seed=seed)
    assert boots_units.std() > boots_plain.std()


def test_bootstrap_arglength():
    """Check that inputs with mismatched lengths raise a ValueError."""
    shorter, longer = np.arange(5), np.arange(10)
    with pytest.raises(ValueError):
        algo.bootstrap(shorter, longer)


def test_bootstrap_string_func():
    """Check that a function name gives the same result as the callable."""
    x = np.random.randn(100)

    for name, np_func in [("mean", np.mean), ("std", np.std)]:
        by_name = algo.bootstrap(x, func=name, seed=0)
        by_callable = algo.bootstrap(x, func=np_func, seed=0)
        assert np.array_equal(by_name, by_callable)

    # An unknown name is expected to surface as an AttributeError
    with pytest.raises(AttributeError):
        algo.bootstrap(x, func="not_a_method_name")


def test_bootstrap_reproducibility(random):
    """Check that equal seeds of each supported kind give equal results."""
    data = np.random.randn(50)

    # Integer seeds
    int_a, int_b = (algo.bootstrap(data, seed=100) for _ in range(2))
    assert_array_equal(int_a, int_b)

    # Fresh RandomState objects created from the same seed
    state_a, state_b = (
        algo.bootstrap(data, seed=np.random.RandomState(200))
        for _ in range(2)
    )
    assert_array_equal(state_a, state_b)

    with pytest.warns(UserWarning):
        # Deprecated, remove when removing random_seed
        old_a = algo.bootstrap(data, random_seed=100)
        old_b = algo.bootstrap(data, random_seed=100)
        assert_array_equal(old_a, old_b)


def test_nanaware_func_auto(random):
    """Check that bootstrapping "mean" over data with a nan gives no nans."""
    sample = np.random.normal(size=10)
    sample[0] = np.nan
    boots = algo.bootstrap(sample, func="mean")
    assert not np.any(np.isnan(boots))


def test_nanaware_func_warning(random):
    """Check that "ptp" on data with a nan warns and propagates the nan."""
    sample = np.random.normal(size=10)
    sample[0] = np.nan
    with pytest.warns(UserWarning, match="Data contain nans but"):
        boots = algo.bootstrap(sample, func="ptp")
    assert np.any(np.isnan(boots))


================================================
FILE: tests/test_axisgrid.py
================================================
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

import pytest
import numpy.testing as npt
from numpy.testing import assert_array_equal, assert_array_almost_equal
import pandas.testing as tm

from seaborn._base import categorical_order
from seaborn import rcmod
from seaborn.palettes import color_palette
from seaborn.relational import scatterplot
from seaborn.distributions import histplot, kdeplot, distplot
from seaborn.categorical import pointplot
from seaborn.utils import _version_predates
from seaborn import axisgrid as ag
from seaborn._testing import (
    assert_plots_equal,
    assert_colors_equal,
)
from seaborn._compat import get_legend_handles

# Module-level random state with a fixed seed, used to build test data
rs = np.random.RandomState(0)


class TestFacetGrid:
    """Tests for the `FacetGrid` class in `seaborn.axisgrid` (`ag`)."""

    # Shared 60-row frame: `x` and `y` are numeric draws from the module-level
    # `rs`; `a`, `b`, `c`, and `d` are string columns with 3, 2, 3, and 12
    # distinct levels respectively.
    df = pd.DataFrame(dict(x=rs.normal(size=60),
                           y=rs.gamma(4, size=60),
                           a=np.repeat(list("abc"), 20),
                           b=np.tile(list("mn"), 30),
                           c=np.tile(list("tuv"), 20),
                           d=np.tile(list("abcdefghijkl"), 5)))

    def test_self_data(self):
        """The grid should hold on to the very frame it was given."""
        grid = ag.FacetGrid(self.df)
        assert grid.data is self.df

    def test_self_figure(self):
        """The `figure` attribute should be the grid's matplotlib Figure."""
        grid = ag.FacetGrid(self.df)
        assert isinstance(grid.figure, plt.Figure)
        assert grid._figure is grid.figure

    def test_self_axes(self):
        """Every entry of the axes array should be a matplotlib Axes."""
        grid = ag.FacetGrid(self.df, row="a", col="b", hue="c")
        assert all(isinstance(ax, plt.Axes) for ax in grid.axes.flat)

    def test_axes_array_size(self):
        """Check the axes array shape for each combination of facet vars."""
        cases = [
            (dict(), (1, 1)),
            (dict(row="a"), (3, 1)),
            (dict(col="b"), (1, 2)),
            # A hue variable alone does not add any facets
            (dict(hue="c"), (1, 1)),
            (dict(row="a", col="b", hue="c"), (3, 2)),
        ]
        for facet_kws, expected_shape in cases:
            g = ag.FacetGrid(self.df, **facet_kws)
            assert g.axes.shape == expected_shape

        # The final (row x col x hue) grid should be filled with real Axes
        for ax in g.axes.flat:
            assert isinstance(ax, plt.Axes)

    def test_single_axes(self):
        """The `ax` shortcut should only work on a single-facet grid."""
        single = ag.FacetGrid(self.df)
        assert isinstance(single.ax, plt.Axes)

        # Any grid with a row and/or column variable should refuse `.ax`
        multi_facet_kws = [
            dict(row="a"),
            dict(col="a"),
            dict(col="a", row="b"),
        ]
        for facet_kws in multi_facet_kws:
            grid = ag.FacetGrid(self.df, **facet_kws)
            with pytest.raises(AttributeError):
                grid.ax

    def test_col_wrap(self):
        """Check axes layout and lookup with and without column wrapping."""
        n_levels = len(self.df.d.unique())
        wrap = 4

        # Unwrapped: a single row holding one column per level
        g_flat = ag.FacetGrid(self.df, col="d")
        assert g_flat.axes.shape == (1, n_levels)
        assert g_flat.facet_axis(0, 8) is g_flat.axes[0, 8]

        # Wrapped: the axes array becomes one-dimensional
        g_wrap = ag.FacetGrid(self.df, col="d", col_wrap=wrap)
        assert g_wrap.axes.shape == (n_levels,)
        assert g_wrap.facet_axis(0, 8) is g_wrap.axes[8]
        assert g_wrap._ncol == wrap
        assert g_wrap._nrow == (n_levels / wrap)

        # Combining a row variable with col_wrap is expected to be rejected
        with pytest.raises(ValueError):
            ag.FacetGrid(self.df, row="b", col="d", col_wrap=wrap)

        # Blank out every row belonging to one level of the column variable
        df_missing = self.df.copy()
        df_missing.loc[df_missing.d == "j"] = np.nan
        g_missing = ag.FacetGrid(df_missing, col="d")
        assert g_missing.axes.shape == (1, n_levels - 1)

        g_missing_wrap = ag.FacetGrid(df_missing, col="d", col_wrap=wrap)
        assert g_missing_wrap.axes.shape == (n_levels - 1,)

        g_single_col = ag.FacetGrid(self.df, col="d", col_wrap=1)
        assert len(list(g_single_col.facet_data())) == n_levels

    def test_normal_axes(self):
        """Check the edge and inner axes groupings on unwrapped grids."""
        empty = np.empty(0, object).flat

        def check_groups(grid, bottom, not_bottom, left, not_left, inner):
            # Compare each private axes grouping with its expected members
            npt.assert_array_equal(grid._bottom_axes, bottom)
            npt.assert_array_equal(grid._not_bottom_axes, not_bottom)
            npt.assert_array_equal(grid._left_axes, left)
            npt.assert_array_equal(grid._not_left_axes, not_left)
            npt.assert_array_equal(grid._inner_axes, inner)

        # Single facet: the lone axes is on both the bottom and left edges
        g = ag.FacetGrid(self.df)
        check_groups(g, g.axes.flat, empty, g.axes.flat, empty, empty)

        # Columns only: every axes is on the bottom row
        g = ag.FacetGrid(self.df, col="c")
        check_groups(
            g, g.axes.flat, empty,
            g.axes[:, 0].flat, g.axes[:, 1:].flat, empty,
        )

        # Rows only: every axes is in the left column
        g = ag.FacetGrid(self.df, row="c")
        check_groups(
            g, g.axes[-1, :].flat, g.axes[:-1, :].flat,
            g.axes.flat, empty, empty,
        )

        # Rows and columns: inner axes touch neither the bottom nor the left
        g = ag.FacetGrid(self.df, col="a", row="c")
        check_groups(
            g, g.axes[-1, :].flat, g.axes[:-1, :].flat,
            g.axes[:, 0].flat, g.axes[:, 1:].flat, g.axes[:-1, 1:].flat,
        )

    def test_wrapped_axes(self):
        """Check the edge and inner axes groupings on a wrapped grid."""
        grid = ag.FacetGrid(self.df, col="a", col_wrap=2)
        axes = grid.axes

        # Three facets wrapped at two columns
        expected = {
            "_bottom_axes": axes[np.array([1, 2])].flat,
            "_not_bottom_axes": axes[:1].flat,
            "_left_axes": axes[np.array([0, 2])].flat,
            "_not_left_axes": axes[np.array([1])].flat,
            "_inner_axes": np.empty(0, object).flat,
        }
        for attr, members in expected.items():
            npt.assert_array_equal(getattr(grid, attr), members)

    def test_axes_dict(self):
        """Check the keys and values of the `axes_dict` mapping."""
        # With no facet variables the mapping should be an empty dict
        g = ag.FacetGrid(self.df)
        assert isinstance(g.axes_dict, dict)
        assert not g.axes_dict

        # With one facet variable the keys are that variable's level names
        single_var_cases = [
            (dict(row="c"), "row_names"),
            (dict(col="c"), "col_names"),
            (dict(col="a", col_wrap=2), "col_names"),
        ]
        for facet_kws, names_attr in single_var_cases:
            g = ag.FacetGrid(self.df, **facet_kws)
            names = getattr(g, names_attr)
            assert list(g.axes_dict.keys()) == names
            for name, ax in zip(names, g.axes.flat):
                assert g.axes_dict[name] is ax

        # With both variables the keys are (row level, column level) pairs
        g = ag.FacetGrid(self.df, row="a", col="c")
        for (row_level, col_level), ax in g.axes_dict.items():
            row_idx = g.row_names.index(row_level)
            col_idx = g.col_names.index(col_level)
            assert g.axes[row_idx, col_idx] is ax

    def test_figure_size(self):
        """Check the figure size implied by the grid shape and sizing args."""
        cases = [
            (dict(row="a", col="b"), (6, 9)),
            (dict(row="a", col="b", height=6), (12, 18)),
            (dict(col="c", height=4, aspect=.5), (6, 4)),
        ]
        for grid_kws, expected_size in cases:
            g = ag.FacetGrid(self.df, **grid_kws)
            npt.assert_array_equal(g.figure.get_size_inches(), expected_size)

    def test_figure_size_with_legend(self):
        """Check how adding a legend affects the figure size."""
        grid_kws = dict(col="a", hue="c", height=4, aspect=.5)
        base_size = (6, 4)

        # An outside legend should make the figure wider
        g = ag.FacetGrid(self.df, **grid_kws)
        npt.assert_array_equal(g.figure.get_size_inches(), base_size)
        g.add_legend()
        assert g.figure.get_size_inches()[0] > 6

        # An inside legend should leave the figure size alone
        g = ag.FacetGrid(self.df, legend_out=False, **grid_kws)
        npt.assert_array_equal(g.figure.get_size_inches(), base_size)
        g.add_legend()
        npt.assert_array_equal(g.figure.get_size_inches(), base_size)

    def test_legend_data(self):
        """Check the title, handle colors, and labels of a hue legend."""
        g = ag.FacetGrid(self.df, hue="a")
        g.map(plt.plot, "x", "y")
        g.add_legend()
        expected_colors = color_palette(n_colors=3)

        legend = g._legend
        assert legend.get_title().get_text() == "a"

        levels = sorted(self.df.a.unique())

        handles = legend.get_lines()
        assert len(handles) == len(levels)
        for handle, color in zip(handles, expected_colors):
            assert_colors_equal(handle.get_color(), color)

        texts = legend.get_texts()
        assert len(texts) == len(levels)
        for text, level in zip(texts, levels):
            assert text.get_text() == level

    def test_legend_data_missing_level(self):
        """Check the legend when `hue_order` names a level absent from data."""
        hue_order = list("azbc")
        g = ag.FacetGrid(self.df, hue="a", hue_order=hue_order)
        g.map(plt.plot, "x", "y")
        g.add_legend()

        # The second color belongs to the level ("z") that has no data
        four_colors = color_palette(n_colors=4)
        expected_colors = [four_colors[0], four_colors[2], four_colors[3]]

        legend = g._legend
        assert legend.get_title().get_text() == "a"

        observed_levels = sorted(self.df.a.unique())

        handles = legend.get_lines()
        assert len(handles) == len(observed_levels)
        for handle, color in zip(handles, expected_colors):
            assert_colors_equal(handle.get_color(), color)

        # Every requested level gets a label, including the missing one
        texts = legend.get_texts()
        assert len(texts) == 4
        for text, level in zip(texts, hue_order):
            assert text.get_text() == level

    def test_get_boolean_legend_data(self):
        """Check the legend contents when the hue variable is boolean."""
        # Add the boolean column to a new frame rather than to the class-level
        # `self.df`, which is shared by every other test in this class.
        df = self.df.assign(b_bool=self.df.b == "m")
        g = ag.FacetGrid(df, hue="b_bool")
        g.map(plt.plot, "x", "y")
        g.add_legend()
        palette = color_palette(n_colors=2)

        assert g._legend.get_title().get_text() == "b_bool"

        # Legend labels are text, so compare against stringified levels
        b_levels = list(map(str, categorical_order(df.b_bool)))

        lines = g._legend.get_lines()
        assert len(lines) == len(b_levels)

        for line, hue in zip(lines, palette):
            assert_colors_equal(line.get_color(), hue)

        labels = g._legend.get_texts()
        assert len(labels) == len(b_levels)

        for label, level in zip(labels, b_levels):
            assert label.get_text() == level

    def test_legend_tuples(self):
        """Check legend labels when the legend data is keyed by tuples."""
        g = ag.FacetGrid(self.df, hue="a")
        g.map(plt.plot, "x", "y")

        handles, labels = g.ax.get_legend_handles_labels()

        # Key each handle by a tuple whose last element is the label
        tuple_keys = [("", text) for text in labels]
        legend_data = {key: handle for key, handle in zip(tuple_keys, handles)}
        g.add_legend(legend_data, tuple_keys)

        for legend_text, label in zip(g._legend.get_texts(), labels):
            assert legend_text.get_text() == label

    def test_legend_options(self):
        """Smoke test for `add_legend` under a few option combinations."""
        g = ag.FacetGrid(self.df, hue="b")
        g.map(plt.plot, "x", "y")
        g.add_legend()

        # Inside legends with the subtitle adjustment switched on and off
        for adjust in (True, False):
            g_inside = ag.FacetGrid(self.df, hue="b", legend_out=False)
            g_inside.add_legend(adjust_subtitles=adjust)

    def test_legendout_with_colwrap(self):
        """Smoke test for an inside legend on a column-wrapped grid."""
        grid_kws = dict(col="d", hue='b', col_wrap=4, legend_out=False)
        grid = ag.FacetGrid(self.df, **grid_kws)
        grid.map(plt.plot, "x", "y", linewidth=3)
        grid.add_legend()

    def test_legend_tight_layout(self):
        """After `tight_layout`, the legend should sit right of the axes."""
        grid = ag.FacetGrid(self.df, hue='b')
        grid.map(plt.plot, "x", "y", linewidth=3)
        grid.add_legend()
        grid.tight_layout()

        axes_bbox = grid.ax.get_window_extent()
        legend_bbox = grid._legend.get_window_extent()

        assert axes_bbox.xmax < legend_bbox.xmin

    def test_subplot_kws(self):
        """Check that `subplot_kws` reach the axes (here: a polar projection)."""
        polar_kws = dict(projection="polar")
        grid = ag.FacetGrid(self.df, despine=False, subplot_kws=polar_kws)
        for ax in grid.axes.flat:
            assert "PolarAxes" in type(ax).__name__

    def test_gridspec_kws(self):
        """Check that `width_ratios` from `gridspec_kws` shape the columns."""
        gridspec_kws = dict(width_ratios=[3, 1, 2])
        g = ag.FacetGrid(self.df, col='c', row='a', gridspec_kws=gridspec_kws)

        # Strip the ticks from every axes before running the layout
        for ax in g.axes.flat:
            ax.set_xticks([])
            ax.set_yticks([])

        g.figure.tight_layout()

        # The middle column has the smallest ratio, so it should be narrowest
        for left, middle, right in g.axes:
            middle_width = middle.get_position().width
            assert left.get_position().width > middle_width
            assert right.get_position().width > middle_width

    def test_gridspec_kws_col_wrap(self):
        """Passing `gridspec_kws` together with `col_wrap` should warn."""
        gridspec_kws = dict(width_ratios=[3, 1, 2, 1, 1])
        with pytest.warns(UserWarning):
            ag.FacetGrid(
                self.df, col='d', col_wrap=5, gridspec_kws=gridspec_kws
            )

    def test_data_generator(self):
        """Check the indices and data subsets yielded by `facet_data`."""

        def check_facet(entry, expected_index, **expected_levels):
            # Each entry is an index tuple paired with the matching subset
            index, subset = entry
            assert index == expected_index
            for var, level in expected_levels.items():
                assert (subset[var] == level).all()

        # Rows only
        facets = list(ag.FacetGrid(self.df, row="a").facet_data())
        assert len(facets) == 3
        check_facet(facets[0], (0, 0, 0), a="a")
        check_facet(facets[1], (1, 0, 0), a="b")

        # Rows and columns: the column index varies fastest
        facets = list(ag.FacetGrid(self.df, row="a", col="b").facet_data())
        assert len(facets) == 6
        check_facet(facets[0], (0, 0, 0), a="a", b="m")
        check_facet(facets[1], (0, 1, 0), a="a", b="n")
        check_facet(facets[2], (1, 0, 0), a="b", b="m")

        # Hue only: the level shows up in the third index position
        facets = list(ag.FacetGrid(self.df, hue="c").facet_data())
        assert len(facets) == 3
        check_facet(facets[1], (0, 0, 1), c="u")

    def test_map(self):
        """Check that `map` draws each hue subset with the given keywords."""
        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")
        g.map(plt.plot, "x", "y", linewidth=3)

        # One line per hue level in each facet
        facet_lines = g.axes[0, 0].lines
        assert len(facet_lines) == 3

        first_line = facet_lines[0]
        assert first_line.get_linewidth() == 3

        # The first line in the first facet holds the first level of each var
        in_first = (
            (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t")
        )
        x_drawn, y_drawn = first_line.get_data()
        npt.assert_array_equal(x_drawn, self.df.x[in_first])
        npt.assert_array_equal(y_drawn, self.df.y[in_first])

    def test_map_dataframe(self):
        """Check that `map_dataframe` passes each subset in as `data`."""
        g = ag.FacetGrid(self.df, row="a", col="b", hue="c")

        def plot(x, y, data=None, **kws):
            plt.plot(data[x], data[y], **kws)
        # Modify __module__ so this doesn't look like a seaborn function
        plot.__module__ = "test"

        g.map_dataframe(plot, "x", "y", linestyle="--")

        facet_lines = g.axes[0, 0].lines
        assert len(g.axes[0, 0].lines) == 3

        first_line = facet_lines[0]
        assert first_line.get_linestyle() == "--"

        # The first line in the first facet holds the first level of each var
        in_first = (
            (self.df.a == "a") & (self.df.b == "m") & (self.df.c == "t")
        )
        x_drawn, y_drawn = first_line.get_data()
        npt.assert_array_equal(x_drawn, self.df.x[in_first])
        npt.assert_array_equal(y_drawn, self.df.y[in_first])

    def test_set(self):
        """Check that `set` applies axes properties to every facet."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        props = dict(
            xlim=(-2, 5),
            ylim=(3, 6),
            xticks=[-2, 0, 3, 5],
            yticks=[3, 4.5, 6],
        )
        g.set(**props)

        # Read each property back through the matching `get_*` method
        for ax in g.axes.flat:
            for name, expected in props.items():
                npt.assert_array_equal(getattr(ax, f"get_{name}")(), expected)

    def test_set_titles(self):
        """Check default and templated facet titles."""

        def titles_at(grid, positions):
            return [grid.axes[pos].get_title() for pos in positions]

        g = ag.FacetGrid(self.df, row="a", col="b")
        g.map(plt.plot, "x", "y")
        positions = [(0, 0), (0, 1), (1, 0)]

        # Test the default titles
        assert titles_at(g, positions) == [
            "a = a | b = m",
            "a = a | b = n",
            "a = b | b = m",
        ]

        # Test a provided title
        g.set_titles("{row_var} == {row_name} \\/ {col_var} == {col_name}")
        assert titles_at(g, positions) == [
            "a == a \\/ b == m",
            "a == a \\/ b == n",
            "a == b \\/ b == m",
        ]

        # Test a single row
        g = ag.FacetGrid(self.df, col="b")
        g.map(plt.plot, "x", "y")

        # Test the default titles
        assert titles_at(g, [(0, 0), (0, 1)]) == ["b = m", "b = n"]

        # test with dropna=False
        g = ag.FacetGrid(self.df, col="b", hue="b", dropna=False)
        g.map(plt.plot, 'x', 'y')

    def test_set_titles_margin_titles(self):
        """Check column titles and row margin texts with `margin_titles`."""
        g = ag.FacetGrid(self.df, row="a", col="b", margin_titles=True)
        g.map(plt.plot, "x", "y")
        axes = g.axes

        # Test the default titles
        assert axes[0, 0].get_title() == "b = m"
        assert axes[0, 1].get_title() == "b = n"
        assert axes[1, 0].get_title() == ""

        # Test the row "titles"
        first_margin = axes[0, 1].texts[0]
        assert first_margin.get_text() == "a = a"
        assert axes[1, 1].texts[0].get_text() == "a = b"
        assert first_margin is g._margin_titles_texts[0]

        # Test provided titles
        g.set_titles(col_template="{col_name}", row_template="{row_name}")
        assert axes[0, 0].get_title() == "m"
        assert axes[0, 1].get_title() == "n"
        assert axes[1, 0].get_title() == ""

        # Resetting the titles should replace the margin text, not add one
        margin_texts = axes[1, 1].texts
        assert len(margin_texts) == 1
        assert margin_texts[0].get_text() == "b"

    def test_set_ticklabels(self):
        """Check setting, thinning, and rotating tick labels on a grid."""

        def label_texts(tick_labels):
            return [label.get_text() for label in tick_labels]

        g = ag.FacetGrid(self.df, row="a", col="b")
        g.map(plt.plot, "x", "y")

        # Build new labels by tagging the current ones with a suffix
        corner_ax = g.axes[-1, 0]
        xlab = [text + "h" for text in label_texts(corner_ax.get_xticklabels())]
        ylab = [text + "i" for text in label_texts(corner_ax.get_yticklabels())]

        g.set_xticklabels(xlab)
        g.set_yticklabels(ylab)
        got_x = label_texts(g.axes[-1, 1].get_xticklabels())
        got_y = label_texts(g.axes[0, 0].get_yticklabels())
        npt.assert_array_equal(got_x, xlab)
        npt.assert_array_equal(got_y, ylab)

        # With `step`, only every second label should remain
        x, y = np.arange(10), np.arange(10)
        df = pd.DataFrame(np.c_[x, y], columns=["x", "y"])
        g = ag.FacetGrid(df).map_dataframe(pointplot, x="x", y="y", order=x)
        g.set_xticklabels(step=2)
        got_x = [
            int(text) for text in label_texts(g.axes[0, 0].get_xticklabels())
        ]
        npt.assert_array_equal(x[::2], got_x)

        # Rotation should reach the labels on the outer axes of a wrapped grid
        g = ag.FacetGrid(self.df, col="d", col_wrap=5)
        g.map(plt.plot, "x", "y")
        x_rotation, y_rotation = 45, 75
        g.set_xticklabels(rotation=x_rotation)
        g.set_yticklabels(rotation=y_rotation)
        for ax in g._bottom_axes:
            for label in ax.get_xticklabels():
                assert label.get_rotation() == x_rotation
        for ax in g._left_axes:
            for label in ax.get_yticklabels():
                assert label.get_rotation() == y_rotation

    def test_set_axis_labels(self):
        """Check that axis labels are set on the outer axes and cleared inside."""
        g = ag.FacetGrid(self.df, row="a", col="b")
        g.map(plt.plot, "x", "y")
        xlab, ylab = 'xx', 'yy'

        g.set_axis_labels(xlab, ylab)

        # The bottom row carries the x label, the left column the y label
        bottom_xlabels = [ax.get_xlabel() for ax in g.axes[-1, :]]
        left_ylabels = [ax.get_ylabel() for ax in g.axes[:, 0]]
        npt.assert_array_equal(bottom_xlabels, xlab)
        npt.assert_array_equal(left_ylabels, ylab)

        # Put labels on every axes, then check that they get removed again
        for ax in g.axes.flat:
            ax.set(xlabel="x", ylabel="y")

        g.set_axis_labels(xlab, ylab)
        for ax in g._not_bottom_axes:
            assert not ax.get_xlabel()
        for ax in g._not_left_axes:
            assert not ax.get_ylabel()

    def test_axis_lims(self):
        """Check that `xlim` and `ylim` given at construction are applied."""
        xlim, ylim = (0, 4), (-2, 3)
        grid = ag.FacetGrid(self.df, row="a", col="b", xlim=xlim, ylim=ylim)
        first_ax = grid.axes[0, 0]
        assert first_ax.get_xlim() == xlim
        assert first_ax.get_ylim() == ylim

    def test_data_orders(self):
        """Check level names and grid shape under default and given orders."""
        facet_vars = dict(row="a", col="b", hue="c")

        # Default orders
        g = ag.FacetGrid(self.df, **facet_vars)
        assert g.row_names == list("abc")
        assert g.col_names == list("mn")
        assert g.hue_names == list("tuv")
        assert g.axes.shape == (3, 2)

        # Explicit orders; the second set includes levels absent from the data
        explicit_cases = [
            ("bca", "nm", "vtu", (3, 2)),
            ("bcda", "nom", "qvtu", (4, 3)),
        ]
        for rows, cols, hues, shape in explicit_cases:
            g = ag.FacetGrid(self.df,
                             row_order=list(rows),
                             col_order=list(cols),
                             hue_order=list(hues),
                             **facet_vars)
            assert g.row_names == list(rows)
            assert g.col_names == list(cols)
            assert g.hue_names == list(hues)
            assert g.axes.shape == shape

    def test_palette(self):
        """Check the hue colors chosen for each kind of palette input."""
        rcmod.set()

        n_c = len(self.df.c.unique())
        n_d = len(self.df.d.unique())

        # No palette given, few levels
        g = ag.FacetGrid(self.df, hue="c")
        assert g._colors == color_palette(n_colors=n_c)

        # No palette given, many levels
        g = ag.FacetGrid(self.df, hue="d")
        assert g._colors == color_palette("husl", n_d)

        # Named palette
        g = ag.FacetGrid(self.df, hue="c", palette="Set2")
        assert g._colors == color_palette("Set2", n_c)

        # Dictionary palette, in the default and in an explicit hue order
        dict_pal = dict(t="red", u="green", v="blue")
        g = ag.FacetGrid(self.df, hue="c", palette=dict_pal)
        assert g._colors == color_palette(["red", "green", "blue"], 3)

        g = ag.FacetGrid(self.df, hue="c", hue_order=list("uvt"),
                         palette=dict_pal)
        assert g._colors == color_palette(["green", "blue", "red"], 3)

    def test_hue_kws(self):
        """Check that `hue_kws` hands one keyword value to each hue level."""
        hue_kws = dict(marker=["o", "s", "D"])
        grid = ag.FacetGrid(self.df, hue="c", hue_kws=hue_kws)
        grid.map(plt.plot, "x", "y")

        drawn_lines = grid.axes[0, 0].lines
        for line, expected_marker in zip(drawn_lines, hue_kws["marker"]):
            assert line.get_marker() == expected_marker

    def test_dropna(self):
        """Check how many rows count as non-null with `dropna` on and off."""
        # Facet variable that is missing in 10 of the 60 rows
        with_gaps = pd.Series(np.tile(np.arange(6), 10), dtype=float)
        with_gaps[with_gaps == 5] = np.nan

        df = self.df.copy()
        df["hasna"] = with_gaps

        for dropna, n_kept in [(False, 60), (True, 50)]:
            g = ag.FacetGrid(df, dropna=dropna, row="hasna")
            assert g._not_na.sum() == n_kept

    def test_categorical_column_missing_categories(self):
        """Unobserved categories should still get a facet each."""
        df = self.df.copy()
        df['a'] = df['a'].astype('category')
        n_categories = len(df['a'].cat.categories)

        # Keep rows from one category only; the dtype still lists all of them
        only_a = df[df['a'] == 'a']
        g = ag.FacetGrid(only_a, col="a", col_wrap=1)

        assert g.axes.shape == (n_categories,)

    def test_categorical_warning(self):
        """Mapping `pointplot` without an explicit order should warn."""
        grid = ag.FacetGrid(self.df, col="b")
        with pytest.warns(UserWarning):
            grid.map(pointplot, "b", "x")

    def test_refline(self):
        """Check the lines that `refline` adds and their default styling."""
        g = ag.FacetGrid(self.df, row="a", col="b")

        # Calling with neither x nor y should draw nothing
        g.refline()
        for ax in g.axes.flat:
            assert not ax.lines

        refx = refy = 0.5
        expected_vline = np.array([[refx, 0], [refx, 1]])
        expected_hline = np.array([[0, refy], [1, refy]])

        # Default styling; the vertical line is added before the horizontal
        g.refline(x=refx, y=refy)
        for ax in g.axes.flat:
            assert ax.lines[0].get_color() == '.5'
            assert ax.lines[0].get_linestyle() == '--'
            assert len(ax.lines) == 2
            npt.assert_array_equal(ax.lines[0].get_xydata(), expected_vline)
            npt.assert_array_equal(ax.lines[1].get_xydata(), expected_hline)

        # Explicit styling is passed through to the new line
        color, linestyle = 'red', '-'
        g.refline(x=refx, color=color, linestyle=linestyle)
        newest = g.axes[0, 0].lines[-1]
        npt.assert_array_equal(newest.get_xydata(), expected_vline)
        assert g.axes[0, 0].lines[-1].get_color() == color
        assert g.axes[0, 0].lines[-1].get_linestyle() == linestyle

    def test_apply(self, long_df):
        """`apply` should run the function and hand back the grid itself."""

        def paint_figure(grid, color):
            grid.figure.set_facecolor(color)

        facecolor = (.1, .6, .3, .9)
        grid = ag.FacetGrid(long_df)
        result = grid.apply(paint_figure, facecolor)

        assert result is grid
        assert grid.figure.get_facecolor() == facecolor

    def test_pipe(self, long_df):
        """`pipe` should run the function and hand back its return value."""

        def paint_figure(grid, color):
            grid.figure.set_facecolor(color)
            return color

        facecolor = (.1, .6, .3, .9)
        grid = ag.FacetGrid(long_df)
        result = grid.pipe(paint_figure, facecolor)

        assert result == facecolor
        assert grid.figure.get_facecolor() == facecolor

    def test_tick_params(self):
        """Check that `tick_params` reaches every major tick on both axes."""
        grid = ag.FacetGrid(self.df, row="a", col="b")
        color, pad = "blue", 3
        grid.tick_params(pad=pad, color=color)

        for ax in grid.axes.flat:
            for axis in (ax.xaxis, ax.yaxis):
                for tick in axis.get_major_ticks():
                    # Each tick has a marker line on either side of the axes
                    for tick_line in (tick.tick1line, tick.tick2line):
                        drawn = tick_line.get_color()
                        assert mpl.colors.same_color(drawn, color)
                    assert tick.get_pad() == pad

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_data_interchange(self, mock_long_df, long_df):
        """Check faceting and plotting from the `mock_long_df` fixture."""
        grid = ag.FacetGrid(mock_long_df, col="a", row="b")
        grid.map(scatterplot, "x", "y")

        # Expected grid shape is taken from the `long_df` fixture
        n_rows = long_df["b"].nunique()
        n_cols = long_df["a"].nunique()
        assert grid.axes.shape == (n_rows, n_cols)

        # One scatter collection per facet
        for ax in grid.axes.flat:
            assert len(ax.collections) == 1


class TestPairGrid:
    """Tests for the `PairGrid` class in `seaborn.axisgrid` (`ag`)."""

    # Class-level random state, seeded from the character codes of "PairGrid"
    rs = np.random.RandomState(sum(map(ord, "PairGrid")))

    # Shared 60-row frame: `x` and `z` are float draws, `y` holds integers
    # in [0, 4), and `a` / `b` are string columns with 3 and 12 levels.
    df = pd.DataFrame(dict(x=rs.normal(size=60),
                           y=rs.randint(0, 4, size=(60)),
                           z=rs.gamma(3, size=60),
                           a=np.repeat(list("abc"), 20),
                           b=np.repeat(list("abcdefghijkl"), 5)))

    def test_self_data(self):
        """The grid should hold on to the very frame it was given."""
        grid = ag.PairGrid(self.df)
        assert grid.data is self.df

    def test_ignore_datelike_data(self):
        """Check that a datetime column is left out of the default variables."""
        df = self.df.copy()
        df['date'] = pd.date_range('2010-01-01', periods=len(df), freq='D')

        # Build the grid from the frame that has the date column. A grid made
        # from `self.df` never sees that column, so it cannot test anything
        # about how date-like data are handled.
        g = ag.PairGrid(df)
        assert "date" not in g.x_vars
        assert "date" not in g.y_vars

        # Adding the date column should not change the default variables
        reference = ag.PairGrid(self.df)
        assert g.x_vars == reference.x_vars
        assert g.y_vars == reference.y_vars

    def test_self_figure(self):
        """The `figure` attribute should be the grid's matplotlib Figure."""
        grid = ag.PairGrid(self.df)
        assert isinstance(grid.figure, plt.Figure)
        assert grid._figure is grid.figure

    def test_self_axes(self):
        """Every entry of the axes array should be a matplotlib Axes."""
        grid = ag.PairGrid(self.df)
        assert all(isinstance(ax, plt.Axes) for ax in grid.axes.flat)

    def test_default_axes(self):
        """With no variables given, the grid should be square over x, y, z."""
        expected_vars = ["x", "y", "z"]
        grid = ag.PairGrid(self.df)

        assert grid.axes.shape == (3, 3)
        assert grid.x_vars == expected_vars
        assert grid.y_vars == expected_vars
        assert grid.square_grid

    @pytest.mark.parametrize("vars", [["z", "x"], np.array(["z", "x"])])
    def test_specific_square_axes(self, vars):
        """Check a square grid built from an explicit `vars` list or array."""
        grid = ag.PairGrid(self.df, vars=vars)
        n_vars = len(vars)
        as_list = list(vars)

        assert grid.axes.shape == (n_vars, n_vars)
        assert grid.x_vars == as_list
        assert grid.y_vars == as_list
        assert grid.square_grid

    def test_remove_hue_from_default(self):
        """The hue variable should only be plotted when asked for by name."""
        hue = "z"

        # By default the hue variable is dropped from the grid variables
        grid = ag.PairGrid(self.df, hue=hue)
        assert hue not in grid.x_vars
        assert hue not in grid.y_vars

        # Listing it in `vars` keeps it on both axes
        grid = ag.PairGrid(self.df, hue=hue, vars=["x", "y", "z"])
        assert hue in grid.x_vars
        assert hue in grid.y_vars

    @pytest.mark.parametrize(
        "x_vars, y_vars",
        [
            (["x", "y"], ["z", "y", "x"]),
            (["x", "y"], "z"),
            (np.array(["x", "y"]), np.array(["z", "y", "x"])),
        ],
    )
    def test_specific_nonsquare_axes(self, x_vars, y_vars):
        """Check a rectangular grid built from separate x and y variables."""
        grid = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
        n_rows, n_cols = len(y_vars), len(x_vars)

        assert grid.axes.shape == (n_rows, n_cols)
        assert grid.x_vars == list(x_vars)
        assert grid.y_vars == list(y_vars)
        assert not grid.square_grid

    def test_corner(self):
        """Check the number of axes and the labels on a corner grid."""
        plot_vars = ["x", "y", "z"]
        n_vars = len(plot_vars)
        # Size of the lower triangle including the diagonal: 1 + 2 + ... + n
        corner_size = n_vars * (n_vars + 1) // 2

        g = ag.PairGrid(self.df, vars=plot_vars, corner=True)
        assert len(g.figure.axes) == corner_size

        # Drawing on the diagonal adds one more axes per variable
        g.map_diag(plt.hist)
        assert len(g.figure.axes) == (corner_size + n_vars)

        for diag_ax in np.diag(g.axes):
            assert not diag_ax.yaxis.get_visible()

        # Without diagonal plots no extra axes are added
        g = ag.PairGrid(self.df, vars=plot_vars, corner=True)
        g.map(scatterplot)
        assert len(g.figure.axes) == corner_size
        assert g.axes[0, 0].get_ylabel() == "x"

    def test_size(self):
        """Check the figure size implied by `height`, `aspect`, and shape."""
        cases = [
            (dict(height=3), (9, 9)),
            (dict(height=4, aspect=.5), (6, 12)),
            (dict(y_vars=["z"], x_vars=["x", "y"], height=2, aspect=2), (8, 2)),
        ]
        for grid_kws, expected_size in cases:
            grid = ag.PairGrid(self.df, **grid_kws)
            npt.assert_array_equal(grid.fig.get_size_inches(), expected_size)

    def test_empty_grid(self):
        """A frame holding only the string columns should be rejected."""
        strings_only = self.df[["a", "b"]]
        with pytest.raises(ValueError, match="No variables found"):
            ag.PairGrid(strings_only)

    def test_map(self):
        """Check the data drawn by `map`, with and without a hue variable."""
        vars = ["x", "y", "z"]
        g1 = ag.PairGrid(self.df)
        g1.map(plt.scatter)

        # Columns index the x variable and rows index the y variable
        for i, axes_i in enumerate(g1.axes):
            for j, ax in enumerate(axes_i):
                x_in = self.df[vars[j]]
                y_in = self.df[vars[i]]
                x_out, y_out = ax.collections[0].get_offsets().T
                npt.assert_array_equal(x_in, x_out)
                npt.assert_array_equal(y_in, y_out)

        g2 = ag.PairGrid(self.df, hue="a")
        g2.map(plt.scatter)

        for i, axes_i in enumerate(g2.axes):
            for j, ax in enumerate(axes_i):
                x_in = self.df[vars[j]]
                y_in = self.df[vars[i]]
                for k, k_level in enumerate(self.df.a.unique()):
                    x_in_k = x_in[self.df.a == k_level]
                    y_in_k = y_in[self.df.a == k_level]
                    x_out, y_out = ax.collections[k].get_offsets().T
                    # Compare inside the loop so that every hue level is
                    # checked, not just the last one
                    npt.assert_array_equal(x_in_k, x_out)
                    npt.assert_array_equal(y_in_k, y_out)

    def test_map_nonsquare(self):
        """Check the data drawn by `map` on a single-column grid."""
        x_vars, y_vars = ["x"], ["y", "z"]
        grid = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
        grid.map(plt.scatter)

        # Every row shares the same x variable
        expected_x = self.df.x
        for row, y_var in enumerate(y_vars):
            offsets = grid.axes[row, 0].collections[0].get_offsets()
            x_out, y_out = offsets.T
            npt.assert_array_equal(expected_x, x_out)
            npt.assert_array_equal(self.df[y_var], y_out)

    def test_map_lower(self):
        """`map_lower` should draw only strictly below the diagonal."""
        vars = ["x", "y", "z"]
        grid = ag.PairGrid(self.df)
        grid.map_lower(plt.scatter)

        # Strict lower triangle: each axes shows its column var against its row var
        lower_rows, lower_cols = np.tril_indices_from(grid.axes, -1)
        for i, j in zip(lower_rows, lower_cols):
            x_out, y_out = grid.axes[i, j].collections[0].get_offsets().T
            npt.assert_array_equal(self.df[vars[j]], x_out)
            npt.assert_array_equal(self.df[vars[i]], y_out)

        # Diagonal and upper triangle: nothing drawn
        upper_rows, upper_cols = np.triu_indices_from(grid.axes)
        for i, j in zip(upper_rows, upper_cols):
            assert len(grid.axes[i, j].collections) == 0

    def test_map_upper(self):
        """`map_upper` should draw only strictly above the diagonal."""
        vars = ["x", "y", "z"]
        grid = ag.PairGrid(self.df)
        grid.map_upper(plt.scatter)

        # Strict upper triangle: each axes shows its column var against its row var
        upper_rows, upper_cols = np.triu_indices_from(grid.axes, 1)
        for i, j in zip(upper_rows, upper_cols):
            x_out, y_out = grid.axes[i, j].collections[0].get_offsets().T
            npt.assert_array_equal(self.df[vars[j]], x_out)
            npt.assert_array_equal(self.df[vars[i]], y_out)

        # Diagonal and lower triangle: nothing drawn
        lower_rows, lower_cols = np.tril_indices_from(grid.axes)
        for i, j in zip(lower_rows, lower_cols):
            assert len(grid.axes[i, j].collections) == 0

    def test_map_mixed_funcsig(self):
        """Map scatterplot on the lower triangle and plt.scatter on the upper."""

        names = ["x", "y", "z"]
        grid = ag.PairGrid(self.df, vars=names)
        grid.map_lower(scatterplot)
        grid.map_upper(plt.scatter)

        # Only the upper-triangle output is compared with the input data
        rows, cols = np.triu_indices_from(grid.axes, 1)
        for row, col in zip(rows, cols):
            offsets = grid.axes[row, col].collections[0].get_offsets()
            x_out, y_out = offsets.T
            npt.assert_array_equal(self.df[names[col]], x_out)
            npt.assert_array_equal(self.df[names[row]], y_out)

    def test_map_diag(self):
        """Check patch counts and fill state of histograms drawn by map_diag."""

        # No hue: 10 patches per axes, the first starting at the data minimum
        plain = ag.PairGrid(self.df)
        plain.map_diag(plt.hist)

        for name, diag_ax in zip(plain.diag_vars, plain.diag_axes):
            bars = diag_ax.patches
            assert len(bars) == 10
            assert pytest.approx(bars[0].get_x()) == self.df[name].min()

        # With hue: 30 patches on each diagonal axes
        hued = ag.PairGrid(self.df, hue="a")
        hued.map_diag(plt.hist)

        for diag_ax in hued.diag_axes:
            assert len(diag_ax.patches) == 30

        # With histtype='step', no patch should be filled
        stepped = ag.PairGrid(self.df, hue="a")
        stepped.map_diag(plt.hist, histtype='step')

        for diag_ax in stepped.diag_axes:
            for patch in diag_ax.patches:
                assert not patch.fill

    def test_map_diag_rectangular(self):
        """Check diag/off-diag mapping when x_vars and y_vars differ in length."""

        # Case 1: more y variables than x variables
        x_vars = ["x", "y"]
        y_vars = ["x", "z", "y"]
        tall = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
        tall.map_diag(plt.hist)
        tall.map_offdiag(plt.scatter)

        assert set(tall.diag_vars) == (set(x_vars) & set(y_vars))

        for name, diag_ax in zip(tall.diag_vars, tall.diag_axes):
            bars = diag_ax.patches
            assert len(bars) == 10
            assert pytest.approx(bars[0].get_x()) == self.df[name].min()

        for col, x_name in enumerate(x_vars):
            for row, y_name in enumerate(y_vars):

                cell = tall.axes[row, col]
                if x_name != y_name:
                    x_out, y_out = cell.collections[0].get_offsets().T
                    assert_array_equal(x_out, self.df[x_name])
                    assert_array_equal(y_out, self.df[y_name])
                else:
                    # Index by column because there are fewer x than y vars
                    diag_ax = tall.diag_axes[col]
                    assert cell.bbox.bounds == diag_ax.bbox.bounds

        # Case 2: same layout with a hue variable
        tall_hue = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars, hue="a")
        tall_hue.map_diag(plt.hist)
        tall_hue.map_offdiag(plt.scatter)

        assert set(tall_hue.diag_vars) == (set(x_vars) & set(y_vars))

        for diag_ax in tall_hue.diag_axes:
            assert len(diag_ax.patches) == 30

        # Case 3: more x variables than y variables
        x_vars = ["x", "y", "z"]
        y_vars = ["x", "z"]
        wide = ag.PairGrid(self.df, x_vars=x_vars, y_vars=y_vars)
        wide.map_diag(plt.hist)
        wide.map_offdiag(plt.scatter)

        assert set(wide.diag_vars) == (set(x_vars) & set(y_vars))

        for name, diag_ax in zip(wide.diag_vars, wide.diag_axes):
            bars = diag_ax.patches
            assert len(bars) == 10
            assert pytest.approx(bars[0].get_x()) == self.df[name].min()

        for col, x_name in enumerate(x_vars):
            for row, y_name in enumerate(y_vars):

                cell = wide.axes[row, col]
                if x_name != y_name:
                    x_out, y_out = cell.collections[0].get_offsets().T
                    assert_array_equal(x_out, self.df[x_name])
                    assert_array_equal(y_out, self.df[y_name])
                else:
                    # Index by row because there are fewer y than x vars
                    diag_ax = wide.diag_axes[row]
                    assert cell.bbox.bounds == diag_ax.bbox.bounds

    def test_map_diag_color(self):
        """Check that a color passed to map_diag is used by hist and kdeplot."""

        color = "red"

        # Histogram patches take the requested face color
        hist_grid = ag.PairGrid(self.df)
        hist_grid.map_diag(plt.hist, color=color)

        for diag_ax in hist_grid.diag_axes:
            for bar in diag_ax.patches:
                assert_colors_equal(bar.get_facecolor(), color)

        # Density lines take the requested line color
        kde_grid = ag.PairGrid(self.df)
        kde_grid.map_diag(kdeplot, color='red')

        for diag_ax in kde_grid.diag_axes:
            for curve in diag_ax.lines:
                assert_colors_equal(curve.get_color(), color)

    def test_map_diag_palette(self):
        """Check kdeplot line colors on the diagonal against the named palette."""

        palette = "muted"
        n_levels = len(self.df.a.unique())
        expected = color_palette(palette, n_colors=n_levels)

        grid = ag.PairGrid(self.df, hue="a", palette=palette)
        grid.map_diag(kdeplot)

        for diag_ax in grid.diag_axes:
            # Lines are compared in reverse order against the palette
            curves = diag_ax.lines[::-1]
            for curve, color in zip(curves, expected):
                assert_colors_equal(curve.get_color(), color)

    def test_map_diag_and_offdiag(self):
        """Check map_offdiag and map_diag each draw only on their own axes."""

        names = ["x", "y", "z"]
        grid = ag.PairGrid(self.df)
        grid.map_offdiag(plt.scatter)
        grid.map_diag(plt.hist)

        for diag_ax in grid.diag_axes:
            assert len(diag_ax.patches) == 10

        # Upper triangle first, then lower triangle; same checks for both
        off_diagonal = (
            np.triu_indices_from(grid.axes, 1),
            np.tril_indices_from(grid.axes, -1),
        )
        for rows, cols in off_diagonal:
            for row, col in zip(rows, cols):
                offsets = grid.axes[row, col].collections[0].get_offsets()
                x_out, y_out = offsets.T
                npt.assert_array_equal(self.df[names[col]], x_out)
                npt.assert_array_equal(self.df[names[row]], y_out)

        # The diagonal cells of the main grid carry no scatter collections
        rows, cols = np.diag_indices_from(grid.axes)
        for row, col in zip(rows, cols):
            assert len(grid.axes[row, col].collections) == 0

    def test_diag_sharey(self):
        """With diag_sharey=True, every diagonal axes has the same y limits."""

        grid = ag.PairGrid(self.df, diag_sharey=True)
        grid.map_diag(kdeplot)

        diag_axes = grid.diag_axes
        for other in diag_axes[1:]:
            assert other.get_ylim() == diag_axes[0].get_ylim()

    def test_map_diag_matplotlib(self):
        """Check the bins keyword sets the patch count, with and without hue."""

        bins = 10

        plain = ag.PairGrid(self.df)
        plain.map_diag(plt.hist, bins=bins)
        for diag_ax in plain.diag_axes:
            assert len(diag_ax.patches) == bins

        # With hue, one set of bins is expected per hue level
        n_levels = len(self.df["a"].unique())
        hued = ag.PairGrid(self.df, hue="a")
        hued.map_diag(plt.hist, bins=bins)
        for diag_ax in hued.diag_axes:
            assert len(diag_ax.patches) == (bins * n_levels)

    def test_palette(self):
        """Check grid.palette for default, named, dict, and reordered palettes."""

        rcmod.set()

        n_a = len(self.df.a.unique())
        n_b = len(self.df.b.unique())

        grid = ag.PairGrid(self.df, hue="a")
        assert grid.palette == color_palette(n_colors=n_a)

        grid = ag.PairGrid(self.df, hue="b")
        assert grid.palette == color_palette("husl", n_b)

        grid = ag.PairGrid(self.df, hue="a", palette="Set2")
        assert grid.palette == color_palette("Set2", n_a)

        # A dict palette is expected to come back as a list of colors
        by_level = dict(a="red", b="green", c="blue")
        grid = ag.PairGrid(self.df, hue="a", palette=by_level)
        assert grid.palette == color_palette(["red", "green", "blue"])

        # The expected list follows hue_order when one is given
        grid = ag.PairGrid(
            self.df, hue="a", hue_order=list("cab"), palette=by_level,
        )
        assert grid.palette == color_palette(["blue", "red", "green"])

    def test_hue_kws(self):
        """Check markers given through hue_kws appear on the lines in order."""

        kws = dict(marker=["o", "s", "d", "+"])

        # First with the default hue order, then with an explicit one
        for order_kws in ({}, {"hue_order": list("dcab")}):
            grid = ag.PairGrid(self.df, hue="a", hue_kws=kws, **order_kws)
            grid.map(plt.plot)

            drawn = grid.axes[0, 0].lines
            for line, marker in zip(drawn, kws["marker"]):
                assert line.get_marker() == marker

    def test_hue_order(self):
        """Check lines follow hue_order for each of the mapping methods."""

        order = list("dcab")

        # (mapping method, axes position to inspect, x column, y column)
        # NOTE(review): other tests read map_diag output from g.diag_axes;
        # confirm g.axes[0, 0].lines is populated after map_diag, otherwise
        # the second case compares nothing.
        cases = [
            (ag.PairGrid.map, (1, 0), "x", "y"),
            (ag.PairGrid.map_diag, (0, 0), "x", "x"),
            (ag.PairGrid.map_lower, (1, 0), "x", "y"),
            (ag.PairGrid.map_upper, (0, 1), "y", "x"),
        ]

        for draw, position, x_name, y_name in cases:

            grid = ag.PairGrid(self.df, hue="a", hue_order=order)
            draw(grid, plt.plot)

            for line, level in zip(grid.axes[position].lines, order):
                x_out, y_out = line.get_xydata().T
                in_level = self.df.a == level
                npt.assert_array_equal(x_out, self.df.loc[in_level, x_name])
                npt.assert_array_equal(y_out, self.df.loc[in_level, y_name])

            plt.close("all")

    def test_hue_order_missing_level(self):
        """Repeat the hue_order checks with an order containing an extra "e"."""

        order = list("dcaeb")

        # (mapping method, axes position to inspect, x column, y column)
        # NOTE(review): other tests read map_diag output from g.diag_axes;
        # confirm g.axes[0, 0].lines is populated after map_diag, otherwise
        # the second case compares nothing.
        cases = [
            (ag.PairGrid.map, (1, 0), "x", "y"),
            (ag.PairGrid.map_diag, (0, 0), "x", "x"),
            (ag.PairGrid.map_lower, (1, 0), "x", "y"),
            (ag.PairGrid.map_upper, (0, 1), "y", "x"),
        ]

        for draw, position, x_name, y_name in cases:

            grid = ag.PairGrid(self.df, hue="a", hue_order=order)
            draw(grid, plt.plot)

            for line, level in zip(grid.axes[position].lines, order):
                x_out, y_out = line.get_xydata().T
                in_level = self.df.a == level
                npt.assert_array_equal(x_out, self.df.loc[in_level, x_name])
                npt.assert_array_equal(y_out, self.df.loc[in_level, y_name])

            plt.close("all")

    def test_hue_in_map(self, long_df):
        """Check a hue vector passed to map() yields three distinct colors."""

        grid = ag.PairGrid(long_df, vars=["x", "y"])
        grid.map(scatterplot, hue=long_df["a"])

        first_ax = grid.axes.flat[0]
        facecolors = first_ax.collections[0].get_facecolors()
        distinct = {tuple(rgba) for rgba in facecolors}
        assert len(distinct) == 3

    def test_nondefault_index(self):
        """Check scatter data when the frame is indexed by column "b"."""

        indexed = self.df.copy().set_index("b")
        names = ["x", "y", "z"]

        # Without hue: one collection per axes holding all the rows
        plain = ag.PairGrid(indexed)
        plain.map(plt.scatter)

        for row, axes_row in enumerate(plain.axes):
            for col, ax in enumerate(axes_row):
                x_out, y_out = ax.collections[0].get_offsets().T
                npt.assert_array_equal(self.df[names[col]], x_out)
                npt.assert_array_equal(self.df[names[row]], y_out)

        # With hue: one collection per level of "a", checked level by level
        hued = ag.PairGrid(indexed, hue="a")
        hued.map(plt.scatter)

        levels = self.df.a.unique()
        for row, axes_row in enumerate(hued.axes):
            for col, ax in enumerate(axes_row):
                x_all = self.df[names[col]]
                y_all = self.df[names[row]]
                for k, level in enumerate(levels):
                    in_level = self.df.a == level
                    x_out, y_out = ax.collections[k].get_offsets().T
                    npt.assert_array_equal(x_all[in_level], x_out)
                    npt.assert_array_equal(y_all[in_level], y_out)

    @pytest.mark.parametrize("func", [scatterplot, plt.scatter])
    def test_dropna(self, func):
        """Check that rows with missing values are left out when dropna=True."""

        n_null = 20
        frame = self.df.copy()
        frame.loc[np.arange(n_null), "x"] = np.nan

        names = ["x", "y", "z"]

        grid = ag.PairGrid(frame, vars=names, dropna=True)
        grid.map(func)

        for row, axes_row in enumerate(grid.axes):
            for col, ax in enumerate(axes_row):
                x_out, y_out = ax.collections[0].get_offsets().T

                # The product is null wherever either input is null
                product = frame[names[col]] * frame[names[row]]
                n_valid = product.notnull().sum()

                assert n_valid == len(x_out)
                assert n_valid == len(y_out)

        # Histogram bar heights should add up to the non-null count
        grid.map_diag(histplot)
        for name, diag_ax in zip(names, grid.diag_axes):
            total = sum(bar.get_height() for bar in diag_ax.patches)
            assert total == frame[name].notna().sum()

    def test_histplot_legend(self):
        """Check the legend has one handle per hue level after histplot."""

        # Tests _extract_legend_handles
        grid = ag.PairGrid(self.df, vars=["x", "y"], hue="a")
        grid.map_offdiag(histplot)
        grid.add_legend()

        handles = get_legend_handles(grid._legend)
        n_levels = len(self.df["a"].unique())
        assert len(handles) == n_levels

    def test_pairplot(self):
        """Check default pairplot output, then diagonal collections with hue."""

        names = ["x", "y", "z"]
        grid = ag.pairplot(self.df)

        for diag_ax in grid.diag_axes:
            assert len(diag_ax.patches) > 1

        # Upper triangle first, then lower triangle; same checks for both
        off_diagonal = (
            np.triu_indices_from(grid.axes, 1),
            np.tril_indices_from(grid.axes, -1),
        )
        for rows, cols in off_diagonal:
            for row, col in zip(rows, cols):
                offsets = grid.axes[row, col].collections[0].get_offsets()
                x_out, y_out = offsets.T
                npt.assert_array_equal(self.df[names[col]], x_out)
                npt.assert_array_equal(self.df[names[row]], y_out)

        rows, cols = np.diag_indices_from(grid.axes)
        for row, col in zip(rows, cols):
            assert len(grid.axes[row, col].collections) == 0

        # With hue, each diagonal axes should hold one collection per level
        grid = ag.pairplot(self.df, hue="a")
        n_levels = len(self.df.a.unique())

        for diag_ax in grid.diag_axes:
            assert len(diag_ax.collections) == n_levels

    def test_pairplot_reg(self):
        """Check pairplot output with kind="reg" and diag_kind="hist"."""

        names = ["x", "y", "z"]
        grid = ag.pairplot(self.df, diag_kind="hist", kind="reg")

        for diag_ax in grid.diag_axes:
            assert len(diag_ax.patches)

        # Upper triangle first, then lower triangle; same checks for both
        off_diagonal = (
            np.triu_indices_from(grid.axes, 1),
            np.tril_indices_from(grid.axes, -1),
        )
        for rows, cols in off_diagonal:
            for row, col in zip(rows, cols):
                ax = grid.axes[row, col]
                x_out, y_out = ax.collections[0].get_offsets().T
                npt.assert_array_equal(self.df[names[col]], x_out)
                npt.assert_array_equal(self.df[names[row]], y_out)

                # One line and two collections are expected per axes
                assert len(ax.lines) == 1
                assert len(ax.collections) == 2

        rows, cols = np.diag_indices_from(grid.axes)
        for row, col in zip(rows, cols):
            assert len(grid.axes[row, col].collections) == 0

    def test_pairplot_reg_hue(self):
        """Check two hue levels differ in color and marker path with kind="reg"."""

        markers = ["o", "s", "d"]
        grid = ag.pairplot(self.df, kind="reg", hue="a", markers=markers)

        collections = grid.axes[-1, 0].collections
        first, other = collections[0], collections[2]

        assert not np.array_equal(first.get_facecolor(), other.get_facecolor())

        first_path = first.get_paths()[0].vertices
        other_path = other.get_paths()[0].vertices
        assert not np.array_equal(first_path, other_path)

    def test_pairplot_diag_kde(self):
        """Check pairplot output with diag_kind="kde"."""

        names = ["x", "y", "z"]
        grid = ag.pairplot(self.df, diag_kind="kde")

        for diag_ax in grid.diag_axes:
            assert len(diag_ax.collections) == 1

        # Upper triangle first, then lower triangle; same checks for both
        off_diagonal = (
            np.triu_indices_from(grid.axes, 1),
            np.tril_indices_from(grid.axes, -1),
        )
        for rows, cols in off_diagonal:
            for row, col in zip(rows, cols):
                offsets = grid.axes[row, col].collections[0].get_offsets()
                x_out, y_out = offsets.T
                npt.assert_array_equal(self.df[names[col]], x_out)
                npt.assert_array_equal(self.df[names[row]], y_out)

        rows, cols = np.diag_indices_from(grid.axes)
        for row, col in zip(rows, cols):
            assert len(grid.axes[row, col].collections) == 0

    def test_pairplot_kde(self):
        """Compare a kind="kde" pairplot cell with a direct kdeplot call."""

        # Reference: bivariate kdeplot of y against x on its own axes
        _, reference_ax = plt.subplots()
        kdeplot(data=self.df, x="x", y="y", ax=reference_ax)

        grid = ag.pairplot(self.df, kind="kde")
        grid_ax = grid.axes[1, 0]

        assert_plots_equal(reference_ax, grid_ax, labels=False)

    def test_pairplot_hist(self):
        """Compare a kind="hist" pairplot cell with a direct histplot call."""

        # Reference: bivariate histplot of y against x on its own axes
        _, reference_ax = plt.subplots()
        histplot(data=self.df, x="x", y="y", ax=reference_ax)

        grid = ag.pairplot(self.df, kind="hist")
        grid_ax = grid.axes[1, 0]

        assert_plots_equal(reference_ax, grid_ax, labels=False)

    @pytest.mark.skipif(_version_predates(mpl, "3.7.0"), reason="Matplotlib bug")
    def test_pairplot_markers(self):
        """Check legend markers differ, and that a short marker list warns."""

        names = ["x", "y", "z"]
        markers = ["o", "X", "s"]

        grid = ag.pairplot(self.df, hue="a", vars=names, markers=markers)
        handles = get_legend_handles(grid._legend)
        assert handles[0].get_marker() != handles[1].get_marker()

        # Passing only the first marker is expected to emit a UserWarning
        with pytest.warns(UserWarning):
            ag.pairplot(self.df, hue="a", vars=names, markers=markers[:-2])

    def test_pairplot_column_multiindex(self):
        """Check diag_vars when the frame's columns are a MultiIndex."""

        columns = pd.MultiIndex.from_arrays([["x", "y"], [1, 2]])
        frame = self.df[["x", "y"]].set_axis(columns, axis=1)

        grid = ag.pairplot(frame)
        assert grid.diag_vars == list(columns)

    def test_corner_despine(self):
        """With despine=False, the top spine of a corner grid stays visible."""

        grid = ag.PairGrid(self.df, corner=True, despine=False)
        grid.map_diag(histplot)

        top_spine = grid.axes[0, 0].spines["top"]
        assert top_spine.get_visible()

    def test_corner_set(self):
        """Check that set() applies x limits on a corner grid."""

        limits = (0, 10)
        grid = ag.PairGrid(self.df, corner=True, despine=False)
        grid.set(xlim=limits)

        bottom_left = grid.axes[-1, 0]
        assert bottom_left.get_xlim() == limits

    def test_legend(self):
        """Check the legend attribute of pairplot with and without hue."""

        with_hue = ag.pairplot(self.df, hue="a")
        assert isinstance(with_hue.legend, mpl.legend.Legend)

        without_hue = ag.pairplot(self.df)
        assert without_hue.legend is None

    def test_tick_params(self):
        """Check tick_params sets color and pad on every major tick."""

        color = "red"
        pad = 3

        grid = ag.PairGrid(self.df)
        grid.tick_params(pad=pad, color=color)

        for ax in grid.axes.flat:
            for axis in (ax.xaxis, ax.yaxis):
                for tick in axis.get_major_ticks():
                    # Both tick marks of each tick take the requested color
                    for mark in (tick.tick1line, tick.tick2line):
                        assert mpl.colors.same_color(mark.get_color(), color)
                    assert tick.get_pad() == pad

    @pytest.mark.skipif(
        condition=not hasattr(pd.api, "interchange"),
        reason="Tests behavior assuming support for dataframe interchange"
    )
    def test_data_interchange(self, mock_long_df, long_df):
        """Check a PairGrid built from the mock_long_df fixture plots all rows."""

        names = ["x", "y", "z"]
        grid = ag.PairGrid(mock_long_df, vars=names, hue="a")
        grid.map(scatterplot)

        assert grid.axes.shape == (3, 3)

        # Each axes should hold as many points as long_df has rows
        n_rows = len(long_df)
        for ax in grid.axes.flat:
            points = ax.collections[0].get_offsets()
            assert len(points) == n_rows


class TestJointGrid:
    """Tests for JointGrid construction, plotting methods, and refline."""

    # Sample data shared by all tests; the draw order (x, then y) is fixed
    rs = np.random.RandomState(sum(map(ord, "JointGrid")))
    x = rs.randn(100)
    y = rs.randn(100)

    # Copy of x with two missing values, used by the dropna test
    x_na = x.copy()
    x_na[[10, 20]] = np.nan

    data = pd.DataFrame({"x": x, "y": y, "x_na": x_na})

    def test_margin_grid_from_lists(self):
        """Build the grid from plain Python lists."""

        grid = ag.JointGrid(x=self.x.tolist(), y=self.y.tolist())
        npt.assert_array_equal(grid.x, self.x)
        npt.assert_array_equal(grid.y, self.y)

    def test_margin_grid_from_arrays(self):
        """Build the grid from numpy arrays."""

        grid = ag.JointGrid(x=self.x, y=self.y)
        npt.assert_array_equal(grid.x, self.x)
        npt.assert_array_equal(grid.y, self.y)

    def test_margin_grid_from_series(self):
        """Build the grid from pandas Series."""

        grid = ag.JointGrid(x=self.data.x, y=self.data.y)
        npt.assert_array_equal(grid.x, self.x)
        npt.assert_array_equal(grid.y, self.y)

    def test_margin_grid_from_dataframe(self):
        """Build the grid from column names and a DataFrame."""

        grid = ag.JointGrid(x="x", y="y", data=self.data)
        npt.assert_array_equal(grid.x, self.x)
        npt.assert_array_equal(grid.y, self.y)

    def test_margin_grid_from_dataframe_bad_variable(self):
        """A column name missing from the DataFrame raises ValueError."""

        with pytest.raises(ValueError):
            ag.JointGrid(x="x", y="bad_column", data=self.data)

    def test_margin_grid_axis_labels(self):
        """Check default joint-axes labels and set_axis_labels."""

        grid = ag.JointGrid(x="x", y="y", data=self.data)
        joint = grid.ax_joint

        assert joint.get_xlabel() == "x"
        assert joint.get_ylabel() == "y"

        grid.set_axis_labels("x variable", "y variable")
        assert joint.get_xlabel() == "x variable"
        assert joint.get_ylabel() == "y variable"

    def test_dropna(self):
        """Check the length of grid.x with dropna off and on."""

        kept = ag.JointGrid(x="x_na", y="y", data=self.data, dropna=False)
        assert len(kept.x) == len(self.x_na)

        dropped = ag.JointGrid(x="x_na", y="y", data=self.data, dropna=True)
        assert len(dropped.x) == pd.notnull(self.x_na).sum()

    def test_axlims(self):
        """Check xlim/ylim reach the joint axes and the matching marginals."""

        lim = (-3, 3)
        grid = ag.JointGrid(x="x", y="y", data=self.data, xlim=lim, ylim=lim)

        assert grid.ax_joint.get_xlim() == lim
        assert grid.ax_joint.get_ylim() == lim

        assert grid.ax_marg_x.get_xlim() == lim
        assert grid.ax_marg_y.get_ylim() == lim

    def test_marginal_ticks(self):
        """Check marginal tick-label visibility follows marginal_ticks."""

        hidden = ag.JointGrid(marginal_ticks=False)
        x_labels = hidden.ax_marg_x.get_yticklabels()
        y_labels = hidden.ax_marg_y.get_xticklabels()
        assert not any(label.get_visible() for label in x_labels)
        assert not any(label.get_visible() for label in y_labels)

        shown = ag.JointGrid(marginal_ticks=True)
        x_labels = shown.ax_marg_x.get_yticklabels()
        y_labels = shown.ax_marg_y.get_xticklabels()
        assert any(label.get_visible() for label in x_labels)
        assert any(label.get_visible() for label in y_labels)

    def test_bivariate_plot(self):
        """Check plot_joint passes x and y to the plotting function."""

        grid = ag.JointGrid(x="x", y="y", data=self.data)
        grid.plot_joint(plt.plot)

        x_out, y_out = grid.ax_joint.lines[0].get_xydata().T
        npt.assert_array_equal(x_out, self.x)
        npt.assert_array_equal(y_out, self.y)

    def test_univariate_plot(self):
        """With x == y, both marginal kdeplot curves have the same values."""

        grid = ag.JointGrid(x="x", y="x", data=self.data)
        grid.plot_marginals(kdeplot)

        # The y marginal's curve is compared through its first coordinate
        _, top_curve = grid.ax_marg_x.lines[0].get_xydata().T
        side_curve, _ = grid.ax_marg_y.lines[0].get_xydata().T
        npt.assert_array_equal(top_curve, side_curve)

    def test_univariate_plot_distplot(self):
        """Check distplot marginals: warning, bin count, matching bars."""

        bins = 10
        grid = ag.JointGrid(x="x", y="x", data=self.data)
        with pytest.warns(UserWarning):
            grid.plot_marginals(distplot, bins=bins)

        top_bars = grid.ax_marg_x.patches
        side_bars = grid.ax_marg_y.patches
        assert len(top_bars) == bins
        assert len(side_bars) == bins

        # Bar height on the x marginal should equal bar width on the y one
        for top, side in zip(top_bars, side_bars):
            assert top.get_height() == side.get_width()

    def test_univariate_plot_matplotlib(self):
        """Check plt.hist marginals draw the requested number of bins."""

        bins = 10
        grid = ag.JointGrid(x="x", y="x", data=self.data)
        grid.plot_marginals(plt.hist, bins=bins)

        assert len(grid.ax_marg_x.patches) == bins
        assert len(grid.ax_marg_y.patches) == bins

    def test_plot(self):
        """Check plot() draws the joint and both marginal plots."""

        grid = ag.JointGrid(x="x", y="x", data=self.data)
        grid.plot(plt.plot, kdeplot)

        x_out, y_out = grid.ax_joint.lines[0].get_xydata().T
        npt.assert_array_equal(x_out, self.x)
        npt.assert_array_equal(y_out, self.x)

        _, top_curve = grid.ax_marg_x.lines[0].get_xydata().T
        side_curve, _ = grid.ax_marg_y.lines[0].get_xydata().T
        npt.assert_array_equal(top_curve, side_curve)

    def test_space(self):
        """With space=0, joint and marginal axes have matching extents."""

        grid = ag.JointGrid(x="x", y="y", data=self.data, space=0)

        # bounds[2] and bounds[3] are compared pairwise below
        joint = grid.ax_joint.bbox.bounds
        top = grid.ax_marg_x.bbox.bounds
        side = grid.ax_marg_y.bbox.bounds

        assert joint[2] == top[2]
        assert joint[3] == side[3]

    @pytest.mark.parametrize("as_vector", [True, False])
    def test_hue(self, long_df, as_vector):
        """Compare a hue-mapped JointGrid with direct calls on an empty grid."""

        # Variables are given either as column names or as the columns
        data, x, y, hue = long_df, "x", "y", "a"
        if as_vector:
            data, x, y, hue = None, long_df[x], long_df[y], long_df[hue]

        grid = ag.JointGrid(data=data, x=x, y=y, hue=hue)
        grid.plot_joint(scatterplot)
        grid.plot_marginals(histplot)

        reference = ag.JointGrid()
        scatterplot(data=long_df, x=x, y=y, hue=hue, ax=reference.ax_joint)
        histplot(data=long_df, x=x, hue=hue, ax=reference.ax_marg_x)
        histplot(data=long_df, y=y, hue=hue, ax=reference.ax_marg_y)

        assert_plots_equal(grid.ax_joint, reference.ax_joint)
        assert_plots_equal(grid.ax_marg_x, reference.ax_marg_x, labels=False)
        assert_plots_equal(grid.ax_marg_y, reference.ax_marg_y, labels=False)

    def test_refline(self):
        """Check refline across joint/marginal and x/y combinations."""

        grid = ag.JointGrid(x="x", y="y", data=self.data)
        grid.plot(scatterplot, histplot)
        joint, marg_x, marg_y = grid.ax_joint, grid.ax_marg_x, grid.ax_marg_y

        # No arguments: nothing should be drawn
        grid.refline()
        assert not any(ax.lines for ax in (joint, marg_x, marg_y))

        refx = refy = 0.5
        horizontal = np.array([[0, refy], [1, refy]])
        vertical = np.array([[refx, 0], [refx, 1]])

        # Both targets disabled: still nothing drawn
        grid.refline(x=refx, y=refy, joint=False, marginal=False)
        assert not any(ax.lines for ax in (joint, marg_x, marg_y))

        # Defaults: both lines on the joint axes, one on each marginal
        grid.refline(x=refx, y=refy)
        assert joint.lines[0].get_color() == '.5'
        assert joint.lines[0].get_linestyle() == '--'
        assert len(joint.lines) == 2
        assert len(marg_x.lines) == 1
        assert len(marg_y.lines) == 1
        npt.assert_array_equal(joint.lines[0].get_xydata(), vertical)
        npt.assert_array_equal(joint.lines[1].get_xydata(), horizontal)
        npt.assert_array_equal(marg_x.lines[0].get_xydata(), vertical)
        npt.assert_array_equal(marg_y.lines[0].get_xydata(), horizontal)

        # Joint only, with custom line properties
        color, linestyle = 'red', '-'
        grid.refline(x=refx, marginal=False, color=color, linestyle=linestyle)
        newest = joint.lines[-1]
        npt.assert_array_equal(newest.get_xydata(), vertical)
        assert newest.get_color() == color
        assert newest.get_linestyle() == linestyle
        assert len(marg_x.lines) == len(marg_y.lines)

        # Marginal only, x: the x marginal gains one line
        grid.refline(x=refx, joint=False)
        npt.assert_array_equal(marg_x.lines[-1].get_xydata(), vertical)
        assert len(marg_x.lines) == len(marg_y.lines) + 1

        # Marginal only, y: the y marginal catches up
        grid.refline(y=refy, joint=False)
        npt.assert_array_equal(marg_y.lines[-1].get_xydata(), horizontal)
        assert len(marg_x.lines) == len(marg_y.lines)

        # Joint only, y: marginal counts are unchanged
        grid.refline(y=refy, marginal=False)
        npt.assert_array_equal(joint.lines[-1].get_xydata(), horizontal)
        assert len(marg_x.lines) == len(marg_y.lines)


class TestJointPlot:
    """Tests for jointplot across its supported kinds and options."""

    # Sample data shared by all tests; the draw order (x, then y) is fixed
    rs = np.random.RandomState(sum(map(ord, "jointplot")))
    x = rs.randn(100)
    y = rs.randn(100)
    data = pd.DataFrame({"x": x, "y": y})

    @staticmethod
    def _assert_grids_match(actual, expected):
        """Compare the joint axes, then both marginals with labels=False."""
        assert_plots_equal(actual.ax_joint, expected.ax_joint)
        assert_plots_equal(actual.ax_marg_x, expected.ax_marg_x, labels=False)
        assert_plots_equal(actual.ax_marg_y, expected.ax_marg_y, labels=False)

    def test_scatter(self):
        """Check default scatter output and the marginal histogram bin edges."""

        grid = ag.jointplot(x="x", y="y", data=self.data)
        assert len(grid.ax_joint.collections) == 1

        x_out, y_out = grid.ax_joint.collections[0].get_offsets().T
        assert_array_equal(self.x, x_out)
        assert_array_equal(self.y, y_out)

        # Marginal bars should start at numpy's "auto" bin edges
        x_starts = [bar.get_x() for bar in grid.ax_marg_x.patches]
        x_edges = np.histogram_bin_edges(self.x, "auto")
        assert_array_almost_equal(x_starts, x_edges[:-1])

        y_starts = [bar.get_y() for bar in grid.ax_marg_y.patches]
        y_edges = np.histogram_bin_edges(self.y, "auto")
        assert_array_almost_equal(y_starts, y_edges[:-1])

    def test_scatter_hue(self, long_df):
        """Compare a hue scatter jointplot with scatterplot plus filled KDEs."""

        grid = ag.jointplot(data=long_df, x="x", y="y", hue="a")

        reference = ag.JointGrid()
        scatterplot(data=long_df, x="x", y="y", hue="a", ax=reference.ax_joint)
        kdeplot(data=long_df, x="x", hue="a", ax=reference.ax_marg_x, fill=True)
        kdeplot(data=long_df, y="y", hue="a", ax=reference.ax_marg_y, fill=True)

        self._assert_grids_match(grid, reference)

    def test_reg(self):
        """Check kind="reg" output on the joint and marginal axes."""

        grid = ag.jointplot(x="x", y="y", data=self.data, kind="reg")
        assert len(grid.ax_joint.collections) == 2

        x_out, y_out = grid.ax_joint.collections[0].get_offsets().T
        assert_array_equal(self.x, x_out)
        assert_array_equal(self.y, y_out)

        # Each marginal should have both patches and lines
        for marginal in (grid.ax_marg_x, grid.ax_marg_y):
            assert marginal.patches

        for marginal in (grid.ax_marg_x, grid.ax_marg_y):
            assert marginal.lines

    def test_resid(self):
        """Check kind="resid" draws on the joint axes and no marginal lines."""

        grid = ag.jointplot(x="x", y="y", data=self.data, kind="resid")
        joint = grid.ax_joint

        assert joint.collections
        assert joint.lines
        assert not grid.ax_marg_x.lines
        assert not grid.ax_marg_y.lines

    def test_hist(self, long_df):
        """Compare a kind="hist" jointplot with direct histplot calls."""

        bins = 3, 6
        grid = ag.jointplot(data=long_df, x="x", y="y", kind="hist", bins=bins)

        # Each marginal takes the bin count of its own dimension
        x_bins, y_bins = bins
        reference = ag.JointGrid()
        histplot(data=long_df, x="x", y="y", ax=reference.ax_joint, bins=bins)
        histplot(data=long_df, x="x", ax=reference.ax_marg_x, bins=x_bins)
        histplot(data=long_df, y="y", ax=reference.ax_marg_y, bins=y_bins)

        self._assert_grids_match(grid, reference)

    def test_hex(self):
        """Check kind="hex" draws a joint collection and marginal patches."""

        grid = ag.jointplot(x="x", y="y", data=self.data, kind="hex")

        assert grid.ax_joint.collections
        assert grid.ax_marg_x.patches
        assert grid.ax_marg_y.patches

    def test_kde(self, long_df):
        """Compare a kind="kde" jointplot with direct kdeplot calls."""

        grid = ag.jointplot(data=long_df, x="x", y="y", kind="kde")

        reference = ag.JointGrid()
        kdeplot(data=long_df, x="x", y="y", ax=reference.ax_joint)
        kdeplot(data=long_df, x="x", ax=reference.ax_marg_x)
        kdeplot(data=long_df, y="y", ax=reference.ax_marg_y)

        self._assert_grids_match(grid, reference)

    def test_kde_hue(self, long_df):
        """Compare a hue kind="kde" jointplot with direct kdeplot calls."""

        grid = ag.jointplot(data=long_df, x="x", y="y", hue="a", kind="kde")

        reference = ag.JointGrid()
        kdeplot(data=long_df, x="x", y="y", hue="a", ax=reference.ax_joint)
        kdeplot(data=long_df, x="x", hue="a", ax=reference.ax_marg_x)
        kdeplot(data=long_df, y="y", hue="a", ax=reference.ax_marg_y)

        self._assert_grids_match(grid, reference)

    def test_color(self):
        """Check the color argument reaches the scatter and the histogram."""

        grid = ag.jointplot(x="x", y="y", data=self.data, color="purple")

        points = grid.ax_joint.collections[0]
        assert_colors_equal(points.get_facecolor(), "purple")

        # Only the first three components of the bar face color are compared
        first_bar = grid.ax_marg_x.patches[0]
        assert_colors_equal(first_bar.get_facecolor()[:3], "purple")

    def test_palette(self, long_df):
        """Compare a palette jointplot with direct calls using the same palette."""

        kws = dict(data=long_df, hue="a", palette="Set2")

        grid = ag.jointplot(x="x", y="y", **kws)

        reference = ag.JointGrid()
        scatterplot(x="x", y="y", ax=reference.ax_joint, **kws)
        kdeplot(x="x", ax=reference.ax_marg_x, fill=True, **kws)
        kdeplot(y="y", ax=reference.ax_marg_y, fill=True, **kws)

        self._assert_grids_match(grid, reference)

    def test_hex_customise(self):
        """Check joint_kws can override the hexbin gridsize."""

        # test that default gridsize can be overridden
        grid = ag.jointplot(
            x="x", y="y", data=self.data, kind="hex",
            joint_kws=dict(gridsize=5),
        )
        joint_collections = grid.ax_joint.collections
        assert len(joint_collections) == 1

        counts = joint_collections[0].get_array()
        assert counts.shape[0] == 28  # 28 hexagons expected for gridsize 5

    def test_bad_kind(self):
        """An unknown kind raises ValueError."""

        with pytest.raises(ValueError):
            ag.jointplot(x="x", y="y", data=self.data, kind="not_a_kind")

    def test_unsupported_hue_kind(self):
        """Passing hue with reg, resid, or hex raises ValueError."""

        for kind in ("reg", "resid", "hex"):
            with pytest.raises(ValueError):
                ag.jointplot(x="x", y="y", hue="a", data=self.data, kind=kind)

    def test_leaky_dict(self):
        """Check jointplot leaves the keyword dicts it is given unchanged."""
        # Validate input dicts are unchanged by jointplot plotting function

        kinds = ("hex", "kde", "resid", "reg", "scatter")
        for kwarg in ("joint_kws", "marginal_kws"):
            for kind in kinds:
                passed = {}
                ag.jointplot(
                    x="x", y="y", data=self.data, kind=kind, **{kwarg: passed}
                )
                assert passed == {}

    def test_distplot_kwarg_warning(self, long_df):
        """Passing rug=True in marginal_kws warns but still draws patches."""

        marginal_kws = dict(rug=True)
        with pytest.warns(UserWarning):
            grid = ag.jointplot(
                data=long_df, x="x", y="y", marginal_kws=marginal_kws,
            )
        assert grid.ax_marg_x.patches

    def test_ax_warning(self, long_df):
        """Passing ax warns but the joint plot is still drawn."""

        current_ax = plt.gca()
        with pytest.warns(UserWarning):
            grid = ag.jointplot(data=long_df, x="x", y="y", ax=current_ax)
        assert grid.ax_joint.collections


================================================
FILE: tests/test_base.py
================================================
import itertools
import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt

import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal
from pandas.testing import assert_frame_equal

from seaborn.axisgrid import FacetGrid
from seaborn._compat import get_colormap, get_converter
from seaborn._base import (
    SemanticMapping,
    HueMapping,
    SizeMapping,
    StyleMapping,
    VectorPlotter,
    variable_type,
    infer_orient,
    unique_dashes,
    unique_markers,
    categorical_order,
)
from seaborn.utils import desaturate
from seaborn.palettes import color_palette


@pytest.fixture(params=[
    {"x": "x", "y": "y"},
    {"x": "t", "y": "y"},
    {"x": "a", "y": "y"},
    {"x": "x", "y": "y", "hue": "y"},
    {"x": "x", "y": "y", "hue": "a"},
    {"x": "x", "y": "y", "size": "a"},
    {"x": "x", "y": "y", "style": "a"},
    {"x": "x", "y": "y", "hue": "s"},
    {"x": "x", "y": "y", "size": "s"},
    {"x": "x", "y": "y", "style": "s"},
    {"x": "x", "y": "y", "hue": "a", "style": "a"},
    {"x": "x", "y": "y", "hue": "a", "size": "b", "style": "b"},
])
def long_variables(request):
    """Yield one variable-to-column assignment per parametrized test run."""
    return request.param


class TestSemanticMapping:
    """Tests for the lookup behavior of the base SemanticMapping class."""

    def test_call_lookup(self):
        """Calling the mapping with a key returns that key's lookup-table value."""
        mapping = SemanticMapping(VectorPlotter())
        table = {"a": 1, "b": 2, "c": 3}
        mapping.lookup_table = table
        for key, expected in table.items():
            assert mapping(key) == expected


class TestHueMapping:

    def test_plotter_default_init(self, long_df):
        """A hue map is attached to the plotter only when hue is assigned."""
        without_hue = VectorPlotter(data=long_df, variables=dict(x="x", y="y"))
        assert not hasattr(without_hue, "_hue_map")

        with_hue = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a"),
        )
        hue_map = with_hue._hue_map
        assert isinstance(hue_map, HueMapping)
        assert hue_map.map_type == with_hue.var_types["hue"]

    def test_plotter_customization(self, long_df):
        """``map_hue`` records the requested palette and level order."""
        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a"),
        )
        level_order = ["b", "a", "c"]
        palette_name = "muted"
        plotter.map_hue(palette=palette_name, order=level_order)

        hue_map = plotter._hue_map
        assert hue_map.palette == palette_name
        assert hue_map.levels == level_order

    def test_hue_map_null(self, flat_series, null_series):
        """With ``null_series`` as hue, every mapping attribute stays None."""
        plotter = VectorPlotter(variables=dict(x=flat_series, hue=null_series))
        mapping = HueMapping(plotter)
        unset_attrs = (
            "levels", "map_type", "palette", "cmap", "norm", "lookup_table",
        )
        for attr in unset_attrs:
            assert getattr(mapping, attr) is None

    def test_hue_map_categorical(self, wide_df, long_df):
        """Exercise categorical hue mapping across palette specs and data kinds.

        The scenarios run in sequence and rebind ``p``, ``m``, and ``palette``,
        so each section depends on the state left by the preceding lines.
        """

        p = VectorPlotter(data=wide_df)
        m = HueMapping(p)
        assert m.levels == wide_df.columns.to_list()
        assert m.map_type == "categorical"
        assert m.cmap is None

        # Test named palette
        palette = "Blues"
        expected_colors = color_palette(palette, wide_df.shape[1])
        expected_lookup_table = dict(zip(wide_df.columns, expected_colors))
        m = HueMapping(p, palette=palette)
        assert m.palette == "Blues"
        assert m.lookup_table == expected_lookup_table

        # Test list palette
        palette = color_palette("Reds", wide_df.shape[1])
        expected_lookup_table = dict(zip(wide_df.columns, palette))
        m = HueMapping(p, palette=palette)
        assert m.palette == palette
        assert m.lookup_table == expected_lookup_table

        # Test dict palette
        colors = color_palette("Set1", 8)
        palette = dict(zip(wide_df.columns, colors))
        m = HueMapping(p, palette=palette)
        assert m.palette == palette
        assert m.lookup_table == palette

        # Test dict with missing keys
        palette = dict(zip(wide_df.columns[:-1], colors))
        with pytest.raises(ValueError):
            HueMapping(p, palette=palette)

        # Test list with wrong number of colors
        palette = colors[:-1]
        with pytest.warns(UserWarning):
            HueMapping(p, palette=palette)

        # Test hue order
        hue_order = ["a", "c", "d"]
        m = HueMapping(p, order=hue_order)
        assert m.levels == hue_order

        # Test long data
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="a"))
        m = HueMapping(p)
        assert m.levels == categorical_order(long_df["a"])
        assert m.map_type == "categorical"
        assert m.cmap is None

        # Test default palette
        m = HueMapping(p)
        hue_levels = categorical_order(long_df["a"])
        expected_colors = color_palette(n_colors=len(hue_levels))
        expected_lookup_table = dict(zip(hue_levels, expected_colors))
        assert m.lookup_table == expected_lookup_table

        # Test missing data: looking up NaN yields the all-zero 4-tuple
        m = HueMapping(p)
        assert m(np.nan) == (0, 0, 0, 0)

        # Test default palette with many levels (26 levels expect "husl" colors)
        x = y = np.arange(26)
        hue = pd.Series(list("abcdefghijklmnopqrstuvwxyz"))
        p = VectorPlotter(variables=dict(x=x, y=y, hue=hue))
        m = HueMapping(p)
        expected_colors = color_palette("husl", n_colors=len(hue))
        expected_lookup_table = dict(zip(hue, expected_colors))
        assert m.lookup_table == expected_lookup_table

        # Test binary data
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="c"))
        m = HueMapping(p)
        assert m.levels == [0, 1]
        assert m.map_type == "categorical"

        # A column holding only one of the two values is still categorical
        for val in [0, 1]:
            p = VectorPlotter(
                data=long_df[long_df["c"] == val],
                variables=dict(x="x", y="y", hue="c"),
            )
            m = HueMapping(p)
            assert m.levels == [val]
            assert m.map_type == "categorical"

        # Test Timestamp data
        p = VectorPlotter(data=long_df, variables=dict(x="x", y="y", hue="t"))
        m = HueMapping(p)
        assert m.levels == [pd.Timestamp(t) for t in long_df["t"].unique()]
        assert m.map_type == "datetime"

        # Test explicit categories
        p = VectorPlotter(data=long_df, variables=dict(x="x", hue="a_cat"))
        m = HueMapping(p)
        assert m.levels == long_df["a_cat"].cat.categories.to_list()
        assert m.map_type == "categorical"

        # Test numeric data with category type
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="s_cat")
        )
        m = HueMapping(p)
        assert m.levels == categorical_order(long_df["s_cat"])
        assert m.map_type == "categorical"
        assert m.cmap is None

        # Test categorical palette specified for numeric data
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="s")
        )
        palette = "deep"
        levels = categorical_order(long_df["s"])
        expected_colors = color_palette(palette, n_colors=len(levels))
        expected_lookup_table = dict(zip(levels, expected_colors))
        m = HueMapping(p, palette=palette)
        assert m.lookup_table == expected_lookup_table
        assert m.map_type == "categorical"

    def test_hue_map_numeric(self, long_df):
        """Exercise numeric hue mapping: colormaps, norms, and explicit colors.

        A single plotter ``p`` (hue mapped to column ``s``) is shared by every
        section; ``palette`` and ``hue_norm`` are rebound as the test proceeds.
        """

        # Probe values spanning [0, 1] plus below-range, above-range, and NaN
        vals = np.concatenate([np.linspace(0, 1, 256), [-.1, 1.1, np.nan]])

        # Test default colormap
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="s")
        )
        hue_levels = list(np.sort(long_df["s"].unique()))
        m = HueMapping(p)
        assert m.levels == hue_levels
        assert m.map_type == "numeric"
        assert m.cmap.name == "seaborn_cubehelix"

        # Test named colormap
        palette = "Purples"
        m = HueMapping(p, palette=palette)
        assert_array_equal(m.cmap(vals), get_colormap(palette)(vals))

        # Test colormap object
        palette = get_colormap("Greens")
        m = HueMapping(p, palette=palette)
        assert_array_equal(m.cmap(vals), palette(vals))

        # Test cubehelix shorthand
        palette = "ch:2,0,light=.2"
        m = HueMapping(p, palette=palette)
        assert isinstance(m.cmap, mpl.colors.ListedColormap)

        # Test specified hue limits
        hue_norm = 1, 4
        m = HueMapping(p, norm=hue_norm)
        assert isinstance(m.norm, mpl.colors.Normalize)
        assert m.norm.vmin == hue_norm[0]
        assert m.norm.vmax == hue_norm[1]

        # Test Normalize object
        hue_norm = mpl.colors.PowerNorm(2, vmin=1, vmax=10)
        m = HueMapping(p, norm=hue_norm)
        assert m.norm is hue_norm

        # Test default colormap values
        hmin, hmax = p.plot_data["hue"].min(), p.plot_data["hue"].max()
        m = HueMapping(p)
        assert m.lookup_table[hmin] == pytest.approx(m.cmap(0.0))
        assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0))

        # Test specified colormap values
        # (limits shifted down by one, so hmax lies above the upper limit and
        # is expected to take the color at the top of the colormap)
        hue_norm = hmin - 1, hmax - 1
        m = HueMapping(p, norm=hue_norm)
        norm_min = (hmin - hue_norm[0]) / (hue_norm[1] - hue_norm[0])
        assert m.lookup_table[hmin] == pytest.approx(m.cmap(norm_min))
        assert m.lookup_table[hmax] == pytest.approx(m.cmap(1.0))

        # Test list of colors
        hue_levels = list(np.sort(long_df["s"].unique()))
        palette = color_palette("Blues", len(hue_levels))
        m = HueMapping(p, palette=palette)
        assert m.lookup_table == dict(zip(hue_levels, palette))

        # A list with one color too many only warns
        palette = color_palette("Blues", len(hue_levels) + 1)
        with pytest.warns(UserWarning):
            HueMapping(p, palette=palette)

        # Test dictionary of colors
        palette = dict(zip(hue_levels, color_palette("Reds")))
        m = HueMapping(p, palette=palette)
        assert m.lookup_table == palette

        # Removing one level from the dict makes it incomplete
        palette.pop(hue_levels[0])
        with pytest.raises(ValueError):
            HueMapping(p, palette=palette)

        # Test invalid palette
        with pytest.raises(ValueError):
            HueMapping(p, palette="not a valid palette")

        # Test bad norm argument
        with pytest.raises(ValueError):
            HueMapping(p, norm="not a norm")

    def test_hue_map_without_hue_dataa(self, long_df):
        """Giving a palette when no hue variable is assigned emits a warning."""
        plotter = VectorPlotter(data=long_df, variables=dict(x="x", y="y"))
        expected_message = "Ignoring `palette`"
        with pytest.warns(UserWarning, match=expected_message):
            HueMapping(plotter, palette="viridis")

    def test_saturation(self, long_df):
        """Colors returned by the mapping are desaturated by ``saturation``."""
        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a"),
        )
        hue_levels = categorical_order(long_df["a"])
        base_colors = color_palette("viridis", len(hue_levels))
        amount = 0.8

        mapping = HueMapping(plotter, palette=base_colors, saturation=amount)
        mapped_colors = mapping(hue_levels)
        for idx, mapped in enumerate(mapped_colors):
            expected = desaturate(base_colors[idx], amount)
            assert mpl.colors.same_color(mapped, expected)


class TestSizeMapping:

    def test_plotter_default_init(self, long_df):
        """A size map is attached to the plotter only when size is assigned."""
        without_size = VectorPlotter(data=long_df, variables=dict(x="x", y="y"))
        assert not hasattr(without_size, "_size_map")

        with_size = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", size="a"),
        )
        size_map = with_size._size_map
        assert isinstance(size_map, SizeMapping)
        assert size_map.map_type == with_size.var_types["size"]

    def test_plotter_customization(self, long_df):
        """``map_size`` pairs the given sizes with the given level order."""
        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", size="a"),
        )
        level_order = ["b", "a", "c"]
        level_sizes = [1, 4, 2]
        plotter.map_size(sizes=level_sizes, order=level_order)

        size_map = plotter._size_map
        assert size_map.lookup_table == dict(zip(level_order, level_sizes))
        assert size_map.levels == level_order

    def test_size_map_null(self, flat_series, null_series):
        """With ``null_series`` as size, the size mapping attributes stay None."""

        p = VectorPlotter(variables=dict(x=flat_series, size=null_series))
        # Build the mapping under test: the size variable must be handled by
        # SizeMapping (a HueMapping here would only re-test the hue code path).
        m = SizeMapping(p)
        assert m.levels is None
        assert m.map_type is None
        assert m.norm is None
        assert m.lookup_table is None

    def test_map_size_numeric(self, long_df):
        """Check numeric size mapping ranges, normalization, and bad arguments."""

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="s"),
        )

        # Test default range of keys in the lookup table values
        m = SizeMapping(p)
        size_values = m.lookup_table.values()
        value_range = min(size_values), max(size_values)
        assert value_range == p._default_size_range

        # Test specified range of size values
        # The (min, max) pair is built first: a bare ``assert a, b == c`` would
        # only check ``a`` and use the comparison as the failure message.
        sizes = 1, 5
        m = SizeMapping(p, sizes=sizes)
        size_values = m.lookup_table.values()
        value_range = min(size_values), max(size_values)
        assert value_range == sizes

        # Test size values with normalization range
        norm = 1, 10
        m = SizeMapping(p, sizes=sizes, norm=norm)
        normalize = mpl.colors.Normalize(*norm, clip=True)
        for key, val in m.lookup_table.items():
            assert val == sizes[0] + (sizes[1] - sizes[0]) * normalize(key)

        # Test size values with normalization object
        # (the object is passed with clip=False, yet the mapping's norm clips)
        norm = mpl.colors.LogNorm(1, 10, clip=False)
        m = SizeMapping(p, sizes=sizes, norm=norm)
        assert m.norm.clip
        for key, val in m.lookup_table.items():
            assert val == sizes[0] + (sizes[1] - sizes[0]) * norm(key)

        # Test bad sizes argument (a string)
        with pytest.raises(ValueError):
            SizeMapping(p, sizes="bad_sizes")

        # Test bad sizes argument (a tuple with three entries)
        with pytest.raises(ValueError):
            SizeMapping(p, sizes=(1, 2, 3))

        # Test bad norm argument
        with pytest.raises(ValueError):
            SizeMapping(p, norm="bad_norm")

    def test_map_size_categorical(self, long_df):
        """Exercise categorical size mapping with lists, dicts, and ranges.

        ``levels`` is computed once from the first plotter and reused by the
        later sections, including those that run after ``p`` is rebound.
        """

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="a"),
        )

        # Test specified size order
        levels = p.plot_data["size"].unique()
        sizes = [1, 4, 6]
        order = [levels[1], levels[2], levels[0]]
        m = SizeMapping(p, sizes=sizes, order=order)
        assert m.lookup_table == dict(zip(order, sizes))

        # Test list of sizes
        order = categorical_order(p.plot_data["size"])
        sizes = list(np.random.rand(len(levels)))
        m = SizeMapping(p, sizes=sizes)
        assert m.lookup_table == dict(zip(order, sizes))

        # Test dict of sizes
        sizes = dict(zip(levels, np.random.rand(len(levels))))
        m = SizeMapping(p, sizes=sizes)
        assert m.lookup_table == sizes

        # Test specified size range
        # (the expected values are reversed, so the first level gets the
        # largest size)
        sizes = (2, 5)
        m = SizeMapping(p, sizes=sizes)
        values = np.linspace(*sizes, len(m.levels))[::-1]
        assert m.lookup_table == dict(zip(m.levels, values))

        # Test explicit categories
        p = VectorPlotter(data=long_df, variables=dict(x="x", size="a_cat"))
        m = SizeMapping(p)
        assert m.levels == long_df["a_cat"].cat.categories.to_list()
        assert m.map_type == "categorical"

        # Test sizes list with wrong length
        sizes = list(np.random.rand(len(levels) + 1))
        with pytest.warns(UserWarning):
            SizeMapping(p, sizes=sizes)

        # Test sizes dict with missing levels
        # (zip stops at the shorter input, leaving one level without a size)
        sizes = dict(zip(levels, np.random.rand(len(levels) - 1)))
        with pytest.raises(ValueError):
            SizeMapping(p, sizes=sizes)

        # Test bad sizes argument
        with pytest.raises(ValueError):
            SizeMapping(p, sizes="bad_size")

    def test_array_palette_deprecation(self, long_df):
        """A numpy-array palette warns and is stored on the mapping as a list."""
        plotter = VectorPlotter(long_df, {"y": "y", "hue": "s"})
        rgb_array = mpl.cm.Blues([.3, .5, .8])[:, :3]
        expected_message = "Numpy array is not a supported type"
        with pytest.warns(UserWarning, match=expected_message):
            mapping = HueMapping(plotter, rgb_array)
        assert mapping.palette == rgb_array.tolist()


class TestStyleMapping:

    def test_plotter_default_init(self, long_df):
        """A style map is attached to the plotter only when style is assigned."""

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        # Check the same attribute that is read below (``_style_map``); a
        # misspelled name would make this negative check pass unconditionally.
        assert not hasattr(p, "_style_map")

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="a"),
        )
        assert isinstance(p._style_map, StyleMapping)

    def test_plotter_customization(self, long_df):
        """``map_style`` records the level order and assigns markers in order."""
        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", style="a"),
        )
        level_order = ["b", "a", "c"]
        level_markers = ["s", "p", "h"]
        plotter.map_style(markers=level_markers, order=level_order)

        style_map = plotter._style_map
        assert style_map.levels == level_order
        assert style_map(level_order, "marker") == level_markers

    def test_style_map_null(self, flat_series, null_series):
        """With ``null_series`` as style, the style mapping has no levels."""

        p = VectorPlotter(variables=dict(x=flat_series, style=null_series))
        # Build the mapping under test: the style variable must be handled by
        # StyleMapping (a HueMapping here would only re-test the hue code path).
        m = StyleMapping(p)
        assert m.levels is None
        assert m.lookup_table is None
        # NOTE(review): ``map_type`` is deliberately not asserted to be None;
        # StyleMapping appears to always be categorical -- confirm against
        # seaborn._base before adding a check for it.

    def test_map_style(self, long_df):
        """Exercise style mapping with default, list, and dict markers/dashes.

        Also covers explicit categories, a custom level order, and the warnings
        and errors raised for incomplete or inconsistent style specifications.
        """

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="a"),
        )

        # Test defaults
        m = StyleMapping(p, markers=True, dashes=True)

        n = len(m.levels)
        for key, dashes in zip(m.levels, unique_dashes(n)):
            assert m(key, "dashes") == dashes

        # Markers are compared through their paths rather than directly
        actual_marker_paths = {
            k: mpl.markers.MarkerStyle(m(k, "marker")).get_path()
            for k in m.levels
        }
        # The loop variable is named ``marker`` so that it does not shadow the
        # mapping ``m`` whose levels are being iterated.
        expected_marker_paths = {
            k: mpl.markers.MarkerStyle(marker).get_path()
            for k, marker in zip(m.levels, unique_markers(n))
        }
        assert actual_marker_paths == expected_marker_paths

        # Test lists
        markers, dashes = ["o", "s", "d"], [(1, 0), (1, 1), (2, 1, 3, 1)]
        m = StyleMapping(p, markers=markers, dashes=dashes)
        for key, mark, dash in zip(m.levels, markers, dashes):
            assert m(key, "marker") == mark
            assert m(key, "dashes") == dash

        # Test dicts
        markers = dict(zip(p.plot_data["style"].unique(), markers))
        dashes = dict(zip(p.plot_data["style"].unique(), dashes))
        m = StyleMapping(p, markers=markers, dashes=dashes)
        for key in m.levels:
            assert m(key, "marker") == markers[key]
            assert m(key, "dashes") == dashes[key]

        # Test explicit categories
        p = VectorPlotter(data=long_df, variables=dict(x="x", style="a_cat"))
        m = StyleMapping(p)
        assert m.levels == long_df["a_cat"].cat.categories.to_list()

        # Test style order with defaults
        order = p.plot_data["style"].unique()[[1, 2, 0]]
        m = StyleMapping(p, markers=True, dashes=True, order=order)
        n = len(order)
        for key, mark, dash in zip(order, unique_markers(n), unique_dashes(n)):
            assert m(key, "dashes") == dash
            assert m(key, "marker") == mark
            obj = mpl.markers.MarkerStyle(mark)
            path = obj.get_path().transformed(obj.get_transform())
            assert_array_equal(m(key, "path").vertices, path.vertices)

        # Test too many levels with style lists
        with pytest.warns(UserWarning):
            StyleMapping(p, markers=["o", "s"], dashes=False)

        with pytest.warns(UserWarning):
            StyleMapping(p, markers=False, dashes=[(2, 1)])

        # Test missing keys with style dicts
        markers, dashes = {"a": "o", "b": "s"}, False
        with pytest.raises(ValueError):
            StyleMapping(p, markers=markers, dashes=dashes)

        markers, dashes = False, {"a": (1, 0), "b": (2, 1)}
        with pytest.raises(ValueError):
            StyleMapping(p, markers=markers, dashes=dashes)

        # Test mixture of filled and unfilled markers
        markers, dashes = ["o", "x", "s"], None
        with pytest.raises(ValueError):
            StyleMapping(p, markers=markers, dashes=dashes)


class TestVectorPlotter:

    def test_flat_variables(self, flat_data):
        """Flat input is treated as wide-form with x/y built from the data."""
        plotter = VectorPlotter()
        plotter.assign_variables(data=flat_data)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y"]
        assert len(plotter.plot_data) == len(flat_data)

        # Inputs lacking a usable index are expected to get positional x values
        try:
            x_expected = flat_data.index
            x_name = flat_data.index.name
        except AttributeError:
            x_expected = np.arange(len(flat_data))
            x_name = None
        y_name = getattr(flat_data, "name", None)

        assert_array_equal(plotter.plot_data["x"], x_expected)
        assert_array_equal(plotter.plot_data["y"], flat_data)

        assert plotter.variables["x"] == x_name
        assert plotter.variables["y"] == y_name

    def test_long_df(self, long_df, long_variables):
        """Variables given as column names are read from a long-form frame."""
        plotter = VectorPlotter()
        plotter.assign_variables(data=long_df, variables=long_variables)
        assert plotter.input_format == "long"
        assert plotter.variables == long_variables

        for var, column in long_variables.items():
            assert_array_equal(plotter.plot_data[var], long_df[column])

    def test_long_df_with_index(self, long_df, long_variables):
        """Variables still resolve after column ``a`` is moved into the index."""
        indexed_df = long_df.set_index("a")
        plotter = VectorPlotter()
        plotter.assign_variables(data=indexed_df, variables=long_variables)
        assert plotter.input_format == "long"
        assert plotter.variables == long_variables

        for var, column in long_variables.items():
            assert_array_equal(plotter.plot_data[var], long_df[column])

    def test_long_df_with_multiindex(self, long_df, long_variables):
        """Variables still resolve after ``a`` and ``x`` become a MultiIndex."""
        indexed_df = long_df.set_index(["a", "x"])
        plotter = VectorPlotter()
        plotter.assign_variables(data=indexed_df, variables=long_variables)
        assert plotter.input_format == "long"
        assert plotter.variables == long_variables

        for var, column in long_variables.items():
            assert_array_equal(plotter.plot_data[var], long_df[column])

    def test_long_dict(self, long_dict, long_variables):
        """Variables given as keys are read from a dict data source."""
        plotter = VectorPlotter()
        plotter.assign_variables(data=long_dict, variables=long_variables)
        assert plotter.input_format == "long"
        assert plotter.variables == long_variables

        for var, dict_key in long_variables.items():
            expected = pd.Series(long_dict[dict_key])
            assert_array_equal(plotter.plot_data[var], expected)

    @pytest.mark.parametrize(
        "vector_type",
        ["series", "numpy", "list"],
    )
    def test_long_vectors(self, long_df, long_variables, vector_type):
        """Vectors passed directly, without ``data``, are handled as long-form."""
        vectors = {}
        for var, column in long_variables.items():
            vector = long_df[column]
            if vector_type == "numpy":
                vector = vector.to_numpy()
            elif vector_type == "list":
                vector = vector.to_list()
            vectors[var] = vector

        plotter = VectorPlotter()
        plotter.assign_variables(variables=vectors)
        assert plotter.input_format == "long"

        assert list(plotter.variables) == list(long_variables)
        # Only the series inputs are checked for carrying their names through
        if vector_type == "series":
            assert plotter.variables == long_variables

        for var, column in long_variables.items():
            assert_array_equal(plotter.plot_data[var], long_df[column])

    def test_long_undefined_variables(self, long_df):
        """Referencing the name ``not_in_df`` raises for x, y, or hue."""
        plotter = VectorPlotter()
        bad_assignments = (
            dict(x="not_in_df"),
            dict(x="x", y="not_in_df"),
            dict(x="x", y="y", hue="not_in_df"),
        )
        for variables in bad_assignments:
            with pytest.raises(ValueError):
                plotter.assign_variables(data=long_df, variables=variables)

    @pytest.mark.parametrize(
        "arg", [[], np.array([]), pd.DataFrame()],
    )
    def test_empty_data_input(self, arg):
        """Empty inputs leave the plotter without any assigned variables."""
        plotter = VectorPlotter()
        plotter.assign_variables(data=arg)
        assert not plotter.variables

        # The vector form is only exercised for the non-dataframe inputs
        if isinstance(arg, pd.DataFrame):
            return

        plotter = VectorPlotter()
        plotter.assign_variables(variables=dict(x=arg, y=arg))
        assert not plotter.variables

    def test_units(self, repeated_df):
        """The ``units`` variable in plot_data matches its source column."""
        plotter = VectorPlotter()
        plotter.assign_variables(
            data=repeated_df, variables={"x": "x", "y": "y", "units": "u"},
        )
        assert_array_equal(plotter.plot_data["units"], repeated_df["u"])

    @pytest.mark.parametrize("name", [3, 4.5])
    def test_long_numeric_name(self, long_df, name):
        """Numeric column names are accepted and recorded as strings."""
        long_df[name] = long_df["x"]
        plotter = VectorPlotter()
        plotter.assign_variables(data=long_df, variables=dict(x=name))
        assert_array_equal(plotter.plot_data["x"], long_df[name])
        assert plotter.variables["x"] == str(name)

    def test_long_hierarchical_index(self, rng):
        """A tuple name selects a column from a frame with MultiIndex columns."""
        columns = pd.MultiIndex.from_product([["a"], ["x", "y"]])
        values = rng.uniform(size=(50, 2))
        frame = pd.DataFrame(values, columns=columns)

        column_key = ("a", "y")
        plotter = VectorPlotter()
        plotter.assign_variables(data=frame, variables={"y": column_key})
        assert_array_equal(plotter.plot_data["y"], frame[column_key])
        assert plotter.variables["y"] == str(column_key)

    def test_long_scalar_and_data(self, long_df):
        """A scalar variable fills its whole column and has no recorded name."""
        scalar = 22
        plotter = VectorPlotter(
            data=long_df, variables={"x": "x", "y": scalar},
        )
        assert (plotter.plot_data["y"] == scalar).all()
        assert plotter.variables["y"] is None

    def test_wide_semantic_error(self, wide_df):
        """Assigning ``hue`` together with wide-form data raises a ValueError."""
        expected_message = (
            "The following variable cannot be assigned with wide-form data: `hue`"
        )
        with pytest.raises(ValueError, match=expected_message):
            VectorPlotter(data=wide_df, variables=dict(hue="a"))

    def test_long_unknown_error(self, long_df):
        """A hue value that cannot be interpreted raises a descriptive error."""
        expected_message = "Could not interpret value `what` for `hue`"
        variables = dict(x="x", hue="what")
        with pytest.raises(ValueError, match=expected_message):
            VectorPlotter(data=long_df, variables=variables)

    def test_long_unmatched_size_error(self, long_df, flat_array):
        """Passing ``flat_array`` as hue alongside ``long_df`` raises an error."""
        expected_message = "Length of ndarray vectors must match length of `data`"
        variables = dict(x="x", hue=flat_array)
        with pytest.raises(ValueError, match=expected_message):
            VectorPlotter(data=long_df, variables=variables)

    def test_wide_categorical_columns(self, wide_df):
        """Wide data with CategoricalIndex columns yields hue levels a, b, c."""
        wide_df.columns = pd.CategoricalIndex(wide_df.columns)
        plotter = VectorPlotter(data=wide_df)
        hue_levels = plotter.plot_data["hue"].unique()
        assert_array_equal(hue_levels, ["a", "b", "c"])

    def test_iter_data_quantitites(self, long_df):
        """Check how many subsets ``iter_data`` yields for various groupings.

        The ``semantics`` list defined early on is reused by most later
        sections, and ``n_subsets`` is recomputed before each group of checks.
        """

        # With no semantic variable assigned, the data come back as one subset
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        out = p.iter_data("hue")
        assert len(list(out)) == 1

        var = "a"
        n_subsets = len(long_df[var].unique())

        # One semantic at a time: one subset per unique value of ``var``
        semantics = ["hue", "size", "style"]
        for semantic in semantics:

            p = VectorPlotter(
                data=long_df,
                variables={"x": "x", "y": "y", semantic: var},
            )
            getattr(p, f"map_{semantic}")()
            out = p.iter_data(semantics)
            assert len(list(out)) == n_subsets

        var = "a"
        n_subsets = len(long_df[var].unique())

        # Two semantics sharing one variable do not multiply the subsets
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var, style=var),
        )
        p.map_hue()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

        # --

        # Reversing the iteration order leaves the subset count unchanged
        out = p.iter_data(semantics, reverse=True)
        assert len(list(out)) == n_subsets

        # --

        var1, var2 = "a", "s"

        n_subsets = len(long_df[var1].unique())

        # Grouping by hue alone ignores the assigned style variable
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, style=var2),
        )
        p.map_hue()
        p.map_style()
        out = p.iter_data(["hue"])
        assert len(list(out)) == n_subsets

        # Expected count: distinct (var1, var2) pairs present in the data
        n_subsets = len(set(list(map(tuple, long_df[[var1, var2]].values))))

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, style=var2),
        )
        p.map_hue()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, size=var2, style=var1),
        )
        p.map_hue()
        p.map_size()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

        # --

        # Expected count: distinct (var1, var2, var3) triples present in the data
        var1, var2, var3 = "a", "s", "b"
        cols = [var1, var2, var3]
        n_subsets = len(set(list(map(tuple, long_df[cols].values))))

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, size=var2, style=var3),
        )
        p.map_hue()
        p.map_size()
        p.map_style()
        out = p.iter_data(semantics)
        assert len(list(out)) == n_subsets

    def test_iter_data_keys(self, long_df):
        """Check the keys and values of the ``sub_vars`` dicts from ``iter_data``."""

        semantics = ["hue", "size", "style"]

        # No semantic assigned: the subset key dict is empty
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
        )
        for sub_vars, _ in p.iter_data("hue"):
            assert sub_vars == {}

        # --

        var = "a"

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var),
        )
        for sub_vars, _ in p.iter_data("hue"):
            assert list(sub_vars) == ["hue"]
            assert sub_vars["hue"] in long_df[var].values

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", size=var),
        )
        for sub_vars, _ in p.iter_data("size"):
            assert list(sub_vars) == ["size"]
            assert sub_vars["size"] in long_df[var].values

        # Two semantics on the same variable always carry matching values
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var, style=var),
        )
        for sub_vars, _ in p.iter_data(semantics):
            assert list(sub_vars) == ["hue", "style"]
            assert sub_vars["hue"] in long_df[var].values
            assert sub_vars["style"] in long_df[var].values
            assert sub_vars["hue"] == sub_vars["style"]

        var1, var2 = "a", "s"

        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, size=var2),
        )
        for sub_vars, _ in p.iter_data(semantics):
            assert list(sub_vars) == ["hue", "size"]
            assert sub_vars["hue"] in long_df[var1].values
            assert sub_vars["size"] in long_df[var2].values

        # NOTE(review): this reassigned ``semantics`` list is not used below;
        # the loop requests only "hue", yet the assertions expect the assigned
        # ``col`` variable to appear in the keys as well.
        semantics = ["hue", "col", "row"]
        p = VectorPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=var1, col=var2),
        )
        for sub_vars, _ in p.iter_data("hue"):
            assert list(sub_vars) == ["hue", "col"]
            assert sub_vars["hue"] in long_df[var1].values
            assert sub_vars["col"] in long_df[var2].values

    def test_iter_data_values(self, long_df):
        """Each yielded subset equals the plot_data rows matching its keys."""
        plotter = VectorPlotter(data=long_df, variables=dict(x="x", y="y"))
        plotter.sort = True
        _, subset = next(plotter.iter_data("hue"))
        assert_frame_equal(subset, plotter.plot_data)

        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a"),
        )
        for keys, subset in plotter.iter_data("hue"):
            mask = plotter.plot_data["hue"] == keys["hue"]
            assert_frame_equal(subset, plotter.plot_data[mask])

        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a", size="s"),
        )
        for keys, subset in plotter.iter_data(["hue", "size"]):
            hue_match = plotter.plot_data["hue"] == keys["hue"]
            size_match = plotter.plot_data["size"] == keys["size"]
            mask = hue_match & size_match
            assert_frame_equal(subset, plotter.plot_data[mask])

    def test_iter_data_reverse(self, long_df):
        """``reverse=True`` yields the hue subsets in reversed level order."""
        expected_order = categorical_order(long_df["a"])[::-1]
        plotter = VectorPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a"),
        )
        subsets = plotter.iter_data("hue", reverse=True)
        for idx, (keys, _) in enumerate(subsets):
            assert keys["hue"] == expected_order[idx]

    def test_iter_data_dropna(self, null_df):
        """Subsets have no missing values by default, but do with dropna=False."""
        plotter = VectorPlotter(
            data=null_df, variables=dict(x="x", y="y", hue="a"),
        )
        plotter.map_hue()
        for _, subset in plotter.iter_data("hue"):
            assert not subset.isna().any().any()

        # At least one subset must retain a missing value when not dropping
        found_missing = False
        for _, subset in plotter.iter_data("hue", dropna=False):
            found_missing |= subset.isna().any().any()
        assert found_missing

    def test_axis_labels(self, long_df):
        """Check the labels that `_add_axis_labels` puts on the axes."""
        _, ax = plt.subplots()

        # (variables, keyword arguments, expected xlabel, expected ylabel)
        cases = [
            ({"x": "a"}, {}, "a", ""),
            ({"y": "a"}, {}, "", "a"),
            ({"x": "a"}, {"default_y": "default"}, "a", "default"),
            (
                {"y": "a"},
                {"default_x": "default", "default_y": "default"},
                "default",
                "a",
            ),
        ]
        for variables, defaults, xlabel, ylabel in cases:
            plotter = VectorPlotter(data=long_df, variables=variables)
            plotter._add_axis_labels(ax, **defaults)
            assert ax.get_xlabel() == xlabel
            assert ax.get_ylabel() == ylabel
            ax.clear()

        # Labels that are already set on the axes must be kept
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "y": "a"})
        ax.set(xlabel="existing", ylabel="also existing")
        plotter._add_axis_labels(ax)
        assert ax.get_xlabel() == "existing"
        assert ax.get_ylabel() == "also existing"

        # Shared y axis: both axes carry the label text, but it is only
        # visible on the first one
        _, (left, right) = plt.subplots(1, 2, sharey=True)
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "y": "y"})
        for shared_ax in (left, right):
            plotter._add_axis_labels(shared_ax)

        assert left.get_xlabel() == "x"
        assert left.get_ylabel() == "y"
        assert left.yaxis.label.get_visible()

        assert right.get_xlabel() == "x"
        assert right.get_ylabel() == "y"
        assert not right.yaxis.label.get_visible()

    @pytest.mark.parametrize(
        "variables",
        [
            {"x": "x", "y": "y"},
            {"x": "x"},
            {"y": "y"},
            {"x": "t", "y": "y"},
            {"x": "x", "y": "a"},
        ],
    )
    def test_attach_basics(self, long_df, variables):
        """Check that `_attach` stores the given axes as `ax` on the plotter."""
        axes = plt.subplots()[1]
        plotter = VectorPlotter(data=long_df, variables=variables)
        plotter._attach(axes)
        assert plotter.ax is axes

    def test_attach_disallowed(self, long_df):
        """Check that `_attach` raises TypeError for disallowed variable types."""
        # (variables, `allowed_types` values that must each be rejected)
        cases = [
            ({"x": "a"}, ["numeric", ["datetime", "numeric"]]),
            ({"x": "x"}, ["categorical"]),
            ({"x": "x", "y": "t"}, [["numeric", "categorical"]]),
        ]
        for variables, rejected in cases:
            # Each case gets fresh axes and a fresh plotter
            _, ax = plt.subplots()
            plotter = VectorPlotter(data=long_df, variables=variables)
            for allowed_types in rejected:
                with pytest.raises(TypeError):
                    plotter._attach(ax, allowed_types=allowed_types)

    def test_attach_log_scale(self, long_df):
        """Check the axis scales that result from the `log_scale` argument."""
        # (variables, log_scale argument, expected x scale, expected y scale)
        cases = [
            ({"x": "x"}, True, "log", "linear"),
            ({"x": "x"}, 2, "log", "linear"),
            ({"y": "y"}, True, "linear", "log"),
            ({"x": "x", "y": "y"}, True, "log", "log"),
            ({"x": "x", "y": "y"}, (True, False), "log", "linear"),
            ({"x": "x", "y": "y"}, (False, 2), "linear", "log"),
            ({"x": "a", "y": "y"}, True, "linear", "log"),
            ({"x": "x", "y": "t"}, True, "log", "linear"),
            ({"x": "a", "y": "b"}, True, "linear", "linear"),
        ]
        for variables, log_scale, x_scale, y_scale in cases:
            _, ax = plt.subplots()
            plotter = VectorPlotter(data=long_df, variables=variables)
            plotter._attach(ax, log_scale=log_scale)
            assert ax.xaxis.get_scale() == x_scale
            assert ax.yaxis.get_scale() == y_scale

    def test_attach_converters(self, long_df):
        """Check which unit converter each axis has after `_attach`."""
        ax = plt.subplots()[1]
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"})
        plotter._attach(ax)
        y_converter = get_converter(ax.yaxis)
        assert get_converter(ax.xaxis) is None
        assert "Date" in type(y_converter).__name__

        ax = plt.subplots()[1]
        plotter = VectorPlotter(data=long_df, variables={"x": "a", "y": "y"})
        plotter._attach(ax)
        x_converter = get_converter(ax.xaxis)
        assert "CategoryConverter" in type(x_converter).__name__
        assert get_converter(ax.yaxis) is None

    def test_attach_facets(self, long_df):
        """Check that attaching a FacetGrid sets `facets` and leaves `ax` None."""
        grid = FacetGrid(long_df, col="a")
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        plotter._attach(grid)
        assert plotter.ax is None
        assert plotter.facets == grid

    def test_scale_transform_identity(self, long_df):
        """Check that both transforms leave values unchanged on default axes."""
        ax = plt.subplots()[1]
        plotter = VectorPlotter(data=long_df, variables={"x": "x"})
        plotter._attach(ax)
        forward, inverse = plotter._get_scale_transforms("x")

        values = np.arange(1, 10)
        assert_array_equal(forward(values), values)
        assert_array_equal(inverse(values), values)

    def test_scale_transform_identity_facets(self, long_df):
        """Check that both transforms leave values unchanged on a FacetGrid."""
        grid = FacetGrid(long_df, col="a")
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        plotter._attach(grid)

        forward, inverse = plotter._get_scale_transforms("x")
        values = np.arange(1, 10)
        assert_array_equal(forward(values), values)
        assert_array_equal(inverse(values), values)

    def test_scale_transform_log(self, long_df):
        """Check the transforms on an x axis that was set to log scale."""
        ax = plt.subplots()[1]
        ax.set_xscale("log")
        plotter = VectorPlotter(data=long_df, variables={"x": "x"})
        plotter._attach(ax)

        forward, inverse = plotter._get_scale_transforms("x")
        values = np.arange(1, 4)
        # Forward is compared with log10, inverse with the power of ten
        assert_array_almost_equal(forward(values), np.log10(values))
        assert_array_almost_equal(inverse(values), 10 ** values)

    def test_scale_transform_facets(self, long_df):
        """Check that the inverse transform undoes the forward one on a grid."""
        grid = FacetGrid(long_df, col="a")
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        plotter._attach(grid)

        forward, inverse = plotter._get_scale_transforms("x")
        values = np.arange(4)
        round_trip = inverse(forward(values))
        assert_array_equal(round_trip, values)

    def test_scale_transform_mixed_facets(self, long_df):
        """Check the error raised when facet axes have differing x scales."""
        grid = FacetGrid(long_df, col="a", sharex=False)
        # Only the first facet is switched to a log scale
        first_ax = grid.axes.flat[0]
        first_ax.set_xscale("log")
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        plotter._attach(grid)

        message = "Cannot determine transform with mixed scales on faceted axes"
        with pytest.raises(RuntimeError, match=message):
            plotter._get_scale_transforms("x")

    def test_attach_shared_axes(self, long_df):
        """Check the number of distinct converters for various grid sharing."""
        xy = {"x": "x", "y": "y"}
        xy_col = {**xy, "col": "a"}
        xy_col_row = {**xy_col, "row": "b"}

        def attach(variables, **grid_kws):
            # Build a FacetGrid and a plotter attached to it; the variables
            # are copied so that every plotter receives its own dict
            grid = FacetGrid(long_df, **grid_kws)
            plotter = VectorPlotter(data=long_df, variables=dict(variables))
            plotter._attach(grid)
            return grid, plotter

        def assert_single(plotter, var):
            # One converter is used for every row of the data
            assert plotter.converters[var].nunique() == 1

        def assert_one_per_level(plotter, var, facet):
            # As many converters as facet levels, and one within each level
            converters = plotter.converters[var]
            levels = plotter.plot_data[facet]
            assert converters.nunique() == levels.nunique()
            assert converters.groupby(levels).nunique().max() == 1

        _, p = attach(xy)
        assert_single(p, "x")

        _, p = attach(xy_col, col="a")
        assert_single(p, "x")
        assert_single(p, "y")

        _, p = attach(xy_col, col="a", sharex=False)
        assert_one_per_level(p, "x", "col")
        assert_single(p, "y")

        _, p = attach(xy_col, col="a", sharex=False, col_wrap=2)
        assert_one_per_level(p, "x", "col")
        assert_single(p, "y")

        _, p = attach(xy_col_row, col="a", row="b")
        assert_single(p, "x")
        assert_single(p, "y")

        g, p = attach(xy_col_row, col="a", row="b", sharex=False)
        assert p.converters["x"].nunique() == len(g.axes.flat)
        assert_single(p, "y")

        _, p = attach(xy_col_row, col="a", row="b", sharex="col")
        assert_one_per_level(p, "x", "col")
        assert_single(p, "y")

        _, p = attach(xy_col_row, col="a", row="b", sharey="row")
        assert_single(p, "x")
        assert_one_per_level(p, "y", "row")

    def test_get_axes_single(self, long_df):
        """Check that `_get_axes` returns the attached axes object."""
        axes = plt.figure().subplots()
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "hue": "a"})
        plotter._attach(axes)
        found = plotter._get_axes({"hue": "a"})
        assert found is axes

    def test_get_axes_facets(self, long_df):
        """Check that `_get_axes` returns the matching entry of `axes_dict`."""
        # Grid with columns only: looked up by the column level
        grid = FacetGrid(long_df, col="a")
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "col": "a"})
        plotter._attach(grid)
        assert plotter._get_axes({"col": "b"}) is grid.axes_dict["b"]

        # Grid with rows and columns: looked up by a (row, col) tuple
        grid = FacetGrid(long_df, col="a", row="c")
        plotter = VectorPlotter(
            data=long_df, variables={"x": "x", "col": "a", "row": "c"},
        )
        plotter._attach(grid)
        found = plotter._get_axes({"row": 1, "col": "b"})
        assert found is grid.axes_dict[(1, "b")]

    def test_comp_data(self, long_df):
        """Check `comp_data` against the axis unit conversion of `plot_data`."""
        plotter = VectorPlotter(data=long_df, variables={"x": "x", "y": "t"})

        # We have disabled this check for now, while it remains part of
        # the internal API, because it will require updating a number of tests
        # with pytest.raises(AttributeError):
        #     p.comp_data

        ax = plt.subplots()[1]
        plotter._attach(ax)

        # x is compared as-is; y is compared after the y axis unit conversion
        assert_array_equal(plotter.comp_data["x"], plotter.plot_data["x"])
        converted_y = ax.yaxis.convert_units(plotter.plot_data["y"])
        assert_array_equal(plotter.comp_data["y"], converted_y)

        plotter = VectorPlotter(data=long_df, variables={"x": "a"})

        ax = plt.subplots()[1]
        plotter._attach(ax)

        converted_x = ax.xaxis.convert_units(plotter.plot_data["x"])
        assert_array_equal(plotter.comp_data["x"], converted_x)

    def test_comp_data_log(self, long_df):
        """Check that `comp_data` is log10-transformed only on the log axis."""
        plotter = VectorPlotter(data=long_df, variables={"x": "z", "y": "y"})
        ax = plt.subplots()[1]
        plotter._attach(ax, log_scale=(True, False))

        expected_x = np.log10(plotter.plot_data["x"])
        assert_array_equal(plotter.comp_data["x"], expected_x)
        assert_array_equal(plotter.comp_data["y"], plotter.plot_data["y"])

    def test_comp_data_category_order(self):
        """Check that `comp_data` positions follow the declared category order."""
        values = pd.Series(["a", "b", "c", "a"], dtype="category")
        values = values.cat.set_categories(["b", "c", "a"], ordered=True)

        plotter = VectorPlotter(variables={"x": values})
        ax = plt.subplots()[1]
        plotter._attach(ax)

        # With the order b, c, a: "a" is at 2, "b" at 0 and "c" at 1
        expected = [2, 0, 1, 2]
        assert_array_equal(plotter.comp_data["x"], expected)

    @pytest.fixture(
        params=itertools.product(
            [None, np.nan, pd.NA],
            ["numeric", "category", "datetime"],
        )
    )
    def comp_data_missing_fixture(self, request):
        """Return (original data, expected comp_data) for `test_comp_data_missing`.

        Parametrized over the missing-value marker and the kind of variable.
        """
        missing, kind = request.param

        expected = [0, 1, np.nan, 2, np.nan, 1]
        if kind == "numeric":
            values = [0, 1, missing, 2, np.inf, 1]
        elif kind == "category":
            values = ["a", "b", missing, "c", pd.NA, "b"]
        elif kind == "datetime":
            # Use 1-based numbers to avoid issue on matplotlib<3.2
            # Could simplify the test a bit when we roll off that version
            expected = [1, 2, np.nan, 3, np.nan, 2]
            dates = mpl.dates.num2date([1, 2, 3, 2])
            # Place the missing marker at index 2 and infinity at index 4
            values = [*dates[:2], missing, dates[2], np.inf, dates[3]]

        return values, expected

    def test_comp_data_missing(self, comp_data_missing_fixture):
        """Check `comp_data` values and dtype for data with missing entries."""
        values, expected = comp_data_missing_fixture
        plotter = VectorPlotter(variables={"x": values})
        plotter._attach(plt.figure().subplots())
        result = plotter.comp_data["x"]
        assert_array_equal(result, expected)
        assert result.dtype == "float"

    def test_comp_data_duplicate_index(self):
        """Check `comp_data` for a series whose index has repeated labels."""
        values = pd.Series([1, 2, 3, 4, 5], index=[1, 1, 1, 2, 2])
        plotter = VectorPlotter(variables={"x": values})
        plotter._attach(plt.figure().subplots())
        assert_array_equal(plotter.comp_data["x"], values)

    def test_comp_data_nullable_dtype(self):
        """Check that an "Int64" series yields float `comp_data` values."""
        values = pd.Series([1, 2, 3, 4], dtype="Int64")
        plotter = VectorPlotter(variables={"x": values})
        plotter._attach(plt.figure().subplots())
        result = plotter.comp_data["x"]
        assert_array_equal(result, values)
        assert result.dtype == "float"

    def test_var_order(self, long_df):
        """Check that `order` given to each `map_*` method sets `var_levels`."""
        requested = ["c", "b", "a"]
        for semantic in ("hue", "size", "style"):
            plotter = VectorPlotter(
                data=long_df, variables={"x": "x", semantic: "a"},
            )
            # Resolve map_hue / map_size / map_style by name
            getattr(plotter, f"map_{semantic}")(order=requested)
            assert plotter.var_levels[semantic] == requested

    def test_scale_native(self, long_df):
        """Check that `scale_native` raises NotImplementedError."""
        plotter = VectorPlotter(data=long_df, variables={"x": "x"})
        with pytest.raises(NotImplementedError):
            plotter.scale_native("x")

    def test_scale_numeric(self, long_df):
        """Check that `scale_numeric` raises NotImplementedError."""
        plotter = VectorPlotter(data=long_df, variables={"y": "y"})
        with pytest.raises(NotImplementedError):
            plotter.scale_numeric("y")

    def test_scale_datetime(self, long_df):
        """Check that `scale_datetime` raises NotImplementedError."""
        plotter = VectorPlotter(data=long_df, variables={"x": "t"})
        with pytest.raises(NotImplementedError):
            plotter.scale_datetime("x")

    def test_scale_categorical(self, long_df):
        """Check the plotter state after calling `scale_categorical`."""
        # Scaling "y" when only "x" was assigned
        plotter = VectorPlotter(data=long_df, variables={"x": "x"})
        plotter.scale_categorical("y")
        assert plotter.variables["y"] is None
        assert plotter.var_types["y"] == "categorical"
        assert (plotter.plot_data["y"] == "").all()

        # Column "s": data gains a `.str` accessor and is sorted
        plotter = VectorPlotter(data=long_df, variables={"x": "s"})
        plotter.scale_categorical("x")
        assert plotter.var_types["x"] == "categorical"
        assert hasattr(plotter.plot_data["x"], "str")
        assert not plotter._var_ordered["x"]
        assert plotter.plot_data["x"].is_monotonic_increasing
        levels = plotter.plot_data["x"].unique()
        assert_array_equal(plotter.var_levels["x"], levels)

        # Column "a": not flagged as ordered
        plotter = VectorPlotter(data=long_df, variables={"x": "a"})
        plotter.scale_categorical("x")
        assert not plotter._var_ordered["x"]
        levels = categorical_order(long_df["a"])
        assert_array_equal(plotter.var_levels["x"], levels)

        # Column "a_cat": flagged as ordered
        plotter = VectorPlotter(data=long_df, variables={"x": "a_cat"})
        plotter.scale_categorical("x")
        assert plotter._var_ordered["x"]
        levels = categorical_order(long_df["a_cat"])
        assert_array_equal(plotter.var_levels["x"], levels)

        # An explicit order is used for the levels and flags the variable
        plotter = VectorPlotter(data=long_df, variables={"x": "a"})
        rolled = np.roll(long_df["a"].unique(), 1)
        plotter.scale_categorical("x", order=rolled)
        assert plotter._var_ordered["x"]
        assert_array_equal(plotter.var_levels["x"], rolled)

        # A formatter is applied to both the data and the levels
        plotter = VectorPlotter(data=long_df, variables={"x": "s"})
        plotter.scale_categorical("x", formatter=lambda x: f"{x:%}")
        assert plotter.plot_data["x"].str.endswith("%").all()
        assert all(level.endswith("%") for level in plotter.var_levels["x"])


class TestCoreFunc:
    """Tests for module-level helper functions (dashes, markers, types, order)."""

    def test_unique_dashes(self):
        """Check that `unique_dashes` returns `n` distinct dash specs."""
        n = 24
        dashes = unique_dashes(n)

        assert len(dashes) == n
        assert len(set(dashes)) == n
        # The first entry is the empty string; the rest are even-length tuples
        assert dashes[0] == ""
        for spec in dashes[1:]:
            assert isinstance(spec, tuple)
            assert not len(spec) % 2

    def test_unique_markers(self):
        """Check that `unique_markers` returns `n` distinct filled markers."""
        n = 24
        markers = unique_markers(n)

        assert len(markers) == n
        assert len(set(markers)) == n
        for m in markers:
            assert mpl.markers.MarkerStyle(m).is_filled()

    def test_variable_type(self):
        """Check the type reported by `variable_type` for various inputs."""
        s = pd.Series([1., 2., 3.])
        assert variable_type(s) == "numeric"
        assert variable_type(s.astype(int)) == "numeric"
        assert variable_type(s.astype(object)) == "numeric"
        assert variable_type(s.to_numpy()) == "numeric"
        assert variable_type(s.to_list()) == "numeric"

        s = pd.Series([1, 2, 3, np.nan], dtype=object)
        assert variable_type(s) == "numeric"

        s = pd.Series([np.nan, np.nan])
        assert variable_type(s) == "numeric"

        s = pd.Series([pd.NA, pd.NA])
        assert variable_type(s) == "numeric"

        s = pd.Series([1, 2, pd.NA], dtype="Int64")
        assert variable_type(s) == "numeric"

        s = pd.Series(["1", "2", "3"])
        assert variable_type(s) == "categorical"
        assert variable_type(s.to_numpy()) == "categorical"
        assert variable_type(s.to_list()) == "categorical"

        # This should arguably be datetime, but we don't currently handle it correctly
        # Test is mainly asserting that this doesn't fail on the boolean check.
        s = pd.timedelta_range(1, periods=3, freq="D").to_series()
        assert variable_type(s) == "categorical"

        s = pd.Series([True, False, False])
        assert variable_type(s) == "numeric"
        assert variable_type(s, boolean_type="categorical") == "categorical"
        s_cat = s.astype("category")
        assert variable_type(s_cat, boolean_type="categorical") == "categorical"
        assert variable_type(s_cat, boolean_type="numeric") == "categorical"

        s = pd.Series([pd.Timestamp(1), pd.Timestamp(2)])
        assert variable_type(s) == "datetime"
        assert variable_type(s.astype(object)) == "datetime"
        assert variable_type(s.to_numpy()) == "datetime"
        assert variable_type(s.to_list()) == "datetime"

    def test_infer_orient(self):
        """Check the orientation, warnings and errors from `infer_orient`."""
        nums = pd.Series(np.arange(6))
        cats = pd.Series(["a", "b"] * 3)
        dates = pd.date_range("1999-09-22", "2006-05-14", 6)

        assert infer_orient(cats, nums) == "x"
        assert infer_orient(nums, cats) == "y"

        assert infer_orient(cats, dates, require_numeric=False) == "x"
        assert infer_orient(dates, cats, require_numeric=False) == "y"

        assert infer_orient(nums, None) == "y"
        with pytest.warns(UserWarning, match="Vertical .+ `x`"):
            assert infer_orient(nums, None, "v") == "y"

        assert infer_orient(None, nums) == "x"
        with pytest.warns(UserWarning, match="Horizontal .+ `y`"):
            assert infer_orient(None, nums, "h") == "x"

        # Only x given, non-numeric: allowed when numeric is not required
        assert infer_orient(cats, None, require_numeric=False) == "y"
        with pytest.raises(TypeError, match="Horizontal .+ `x`"):
            infer_orient(cats, None)

        # Only y given, non-numeric: mirrors the case above
        assert infer_orient(None, cats, require_numeric=False) == "x"
        with pytest.raises(TypeError, match="Vertical .+ `y`"):
            infer_orient(None, cats)

        assert infer_orient(nums, nums, "vert") == "x"
        assert infer_orient(nums, nums, "hori") == "y"

        assert infer_orient(cats, cats, "h", require_numeric=False) == "y"
        assert infer_orient(cats, cats, "v", require_numeric=False) == "x"
        assert infer_orient(cats, cats, require_numeric=False) == "x"

        with pytest.raises(TypeError, match="Vertical .+ `y`"):
            infer_orient(cats, cats, "x")
        with pytest.raises(TypeError, match="Horizontal .+ `x`"):
            infer_orient(cats, cats, "y")
        with pytest.raises(TypeError, match="Neither"):
            infer_orient(cats, cats)

        with pytest.raises(ValueError, match="`orient` must start with"):
            infer_orient(cats, nums, orient="bad value")

    def test_categorical_order(self):
        """Check the level order returned by `categorical_order`."""
        x = ["a", "c", "c", "b", "a", "d"]
        y = [3, 2, 5, 1, 4]
        order = ["a", "b", "c", "d"]

        # Strings keep their order of first appearance
        out = categorical_order(x)
        assert out == ["a", "c", "b", "d"]

        # An explicit order is returned as given, even if it is a subset
        out = categorical_order(x, order)
        assert out == order

        out = categorical_order(x, ["b", "a"])
        assert out == ["b", "a"]

        out = categorical_order(np.array(x))
        assert out == ["a", "c", "b", "d"]

        out = categorical_order(pd.Series(x))
        assert out == ["a", "c", "b", "d"]

        # Numbers come back sorted
        out = categorical_order(y)
        assert out == [1, 2, 3, 4, 5]

        out = categorical_order(np.array(y))
        assert out == [1, 2, 3, 4, 5]

        out = categorical_order(pd.Series(y))
        assert out == [1, 2, 3, 4, 5]

        # Categorical data use the declared categories
        x = pd.Categorical(x, order)
        out = categorical_order(x)
        assert out == list(x.categories)

        x = pd.Series(x)
        out = categorical_order(x)
        assert out == list(x.cat.categories)

        out = categorical_order(x, ["b", "a"])
        assert out == ["b", "a"]

        # Missing values are left out of the levels
        x = ["a", np.nan, "c", "c", "b", "a", "d"]
        out = categorical_order(x)
        assert out == ["a", "c", "b", "d"]


================================================
FILE: tests/test_categorical.py
================================================
import itertools
from functools import partial
import warnings

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import same_color, to_rgb, to_rgba

import pytest
from pytest import approx
from numpy.testing import (
    assert_array_equal,
    assert_array_less,
    assert_array_almost_equal,
)

from seaborn import categorical as cat

from seaborn._base import categorical_order
from seaborn._compat import get_colormap, get_legend_handles
from seaborn._testing import assert_plots_equal
from seaborn.categorical import (
    _CategoricalPlotter,
    Beeswarm,
    BoxPlotContainer,
    catplot,
    barplot,
    boxplot,
    boxenplot,
    countplot,
    pointplot,
    stripplot,
    swarmplot,
    violinplot,
)
from seaborn.palettes import color_palette
from seaborn.utils import _draw_figure, _version_predates, desaturate


# Plotting functions that the parametrized tests in
# TestCategoricalPlotterNew are run against.
PLOT_FUNCS = [
    catplot,
    barplot,
    boxplot,
    boxenplot,
    pointplot,
    stripplot,
    swarmplot,
    violinplot,
]


class TestCategoricalPlotterNew:
    """Tests run across PLOT_FUNCS, plus a `_CategoricalPlotter` backcompat test."""

    @pytest.mark.parametrize(
        "func,kwargs",
        itertools.product(
            PLOT_FUNCS,
            [
                {"x": "x", "y": "a"},
                {"x": "a", "y": "y"},
                {"x": "y"},
                {"y": "x"},
            ],
        ),
    )
    def test_axis_labels(self, long_df, func, kwargs):
        """Check that axis labels match the variable names passed in."""
        func(data=long_df, **kwargs)

        current_ax = plt.gca()
        for axis in "xy":
            # An unassigned axis is expected to have an empty label
            expected = kwargs.get(axis, "")
            actual = getattr(current_ax, f"get_{axis}label")()
            assert actual == expected

    @pytest.mark.parametrize("func", PLOT_FUNCS)
    def test_empty(self, func):
        """Check that calls without data add no artists to the axes."""
        for call_kws in ({}, {"x": [], "y": []}):
            func(**call_kws)
            current_ax = plt.gca()
            assert not current_ax.collections
            assert not current_ax.patches
            assert not current_ax.lines

    def test_redundant_hue_backcompat(self, long_df):
        """Check `_hue_backcompat` with a palette but no hue and `force_hue`."""
        plotter = _CategoricalPlotter(
            data=long_df,
            variables={"x": "s", "y": "y"},
        )

        levels = long_df["s"].unique()
        palette = dict(zip(levels, color_palette()))

        # color and hue_order are both passed as None
        palette, _ = plotter._hue_backcompat(None, palette, None, force_hue=True)

        assert plotter.variables["hue"] == "s"
        assert_array_equal(plotter.plot_data["hue"], plotter.plot_data["x"])
        assert all(isinstance(key, str) for key in palette)


class SharedAxesLevelTests:
    """Shared checks for classes that wrap one axes-level plotting function.

    The methods rely on `self.func` and `self.get_last_color`, neither of
    which is defined in this class.
    """

    def orient_indices(self, orient):
        """Return the (position, value) indices for the given orient axis."""
        axes = ["x", "y"]
        pos_idx = axes.index(orient)
        val_idx = axes[::-1].index(orient)
        return pos_idx, val_idx

    @pytest.fixture
    def common_kws(self):
        """Provide an empty dict of keyword arguments shared by the tests."""
        return {}

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_labels_long(self, long_df, orient):
        """Check axis labels, tick labels and legend text for long-form data."""
        depend = "y" if orient == "x" else "x"
        kws = {orient: "a", depend: "y", "hue": "b"}

        ax = self.func(long_df, **kws)

        # To populate texts; only needed on older matplotlibs
        _draw_figure(ax.figure)

        for axis in (orient, depend):
            assert getattr(ax, f"get_{axis}label")() == kws[axis]

        # Tick labels on the orient axis are the levels of its variable
        tick_texts = getattr(ax, f"get_{orient}ticklabels")()
        tick_labels = [text.get_text() for text in tick_texts]
        assert tick_labels == categorical_order(long_df[kws[orient]])

        legend = ax.get_legend()
        assert legend.get_title().get_text() == kws["hue"]

        # Legend entries are the levels of the hue variable
        legend_labels = [text.get_text() for text in legend.texts]
        assert legend_labels == categorical_order(long_df[kws["hue"]])

    def test_labels_wide(self, wide_df):
        """Check the x label and tick labels for wide-form data."""
        wide_df = wide_df.rename_axis("cols", axis=1)
        ax = self.func(wide_df)

        # To populate texts; only needed on older matplotlibs
        _draw_figure(ax.figure)

        assert ax.get_xlabel() == wide_df.columns.name
        tick_labels = [text.get_text() for text in ax.get_xticklabels()]
        for tick_label, column in zip(tick_labels, wide_df.columns):
            assert tick_label == column

    def test_labels_hue_order(self, long_df):
        """Check that the legend entries follow the given `hue_order`."""
        hue_var = "b"
        reversed_levels = categorical_order(long_df[hue_var])[::-1]
        ax = self.func(
            long_df, x="a", y="y", hue=hue_var, hue_order=reversed_levels,
        )
        legend_texts = ax.get_legend().texts
        assert [text.get_text() for text in legend_texts] == reversed_levels

    def test_color(self, long_df, common_kws):
        """Check the default color cycle and an explicit `color` argument."""
        common_kws.update(data=long_df, x="a", y="y")

        # A single call on fresh axes gives "C0"
        ax = plt.figure().subplots()
        self.func(ax=ax, **common_kws)
        assert self.get_last_color(ax) == to_rgba("C0")

        # The second of two calls on the same axes gives "C1"
        ax = plt.figure().subplots()
        for _ in range(2):
            self.func(ax=ax, **common_kws)
        assert self.get_last_color(ax) == to_rgba("C1")

        # An explicit color is the one that gets used
        for color in ("C2", "C3"):
            ax = plt.figure().subplots()
            self.func(color=color, ax=ax, **common_kws)
            assert self.get_last_color(ax) == to_rgba(color)

    def test_two_calls(self):
        """Check the x limits after plotting two sets of categories."""
        ax = plt.figure().subplots()
        first = {"x": ["a", "b", "c"], "y": [1, 2, 3]}
        second = {"x": ["e", "f"], "y": [4, 5]}
        for data in (first, second):
            self.func(ax=ax, **data)
        assert ax.get_xlim() == (-.5, 4.5)

    def test_redundant_hue_legend(self, long_df):
        """Check the legend when hue repeats the variable assigned to x."""
        variables = {"x": "a", "y": "y", "hue": "a"}

        # By default there is no legend
        ax = self.func(long_df, **variables)
        assert ax.get_legend() is None
        ax.clear()

        # No `ax` is passed here; presumably the call draws on the current
        # axes, which is the `ax` checked below
        self.func(long_df, legend=True, **variables)
        assert ax.get_legend() is not None

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_log_scale(self, long_df, orient):
        """Check that `log_scale=True` only affects the dependent axis."""
        depvar = "y" if orient == "x" else "x"
        ax = self.func(long_df, log_scale=True, **{orient: "a", depvar: "z"})
        orient_scale = getattr(ax, f"get_{orient}scale")()
        depvar_scale = getattr(ax, f"get_{depvar}scale")()
        assert orient_scale == "linear"
        assert depvar_scale == "log"


class SharedScatterTests(SharedAxesLevelTests):
    """Tests functionality common to stripplot and swarmplot."""

    def get_last_color(self, ax):
        """Return the single RGBA face color of the last collection on `ax`."""
        facecolors = ax.collections[-1].get_facecolors()
        distinct = np.unique(facecolors, axis=0)
        # Every point in the collection must share one color
        assert len(distinct) == 1
        return to_rgba(distinct.squeeze())

    # ------------------------------------------------------------------------------

    def test_color(self, long_df, common_kws):
        """Run the inherited color checks, then the face color keywords."""
        super().test_color(long_df, common_kws)

        # Both spellings of the face color keyword set the point color
        for keyword, color in [("facecolor", "C4"), ("fc", "C5")]:
            ax = plt.figure().subplots()
            self.func(data=long_df, x="a", y="y", ax=ax, **{keyword: color})
            assert self.get_last_color(ax) == to_rgba(color)

    def test_supplied_color_array(self, long_df):
        """Check that per-point colors passed by the caller are used as given."""
        cmap = get_colormap("Blues")
        norm = mpl.colors.Normalize()
        expected = cmap(norm(long_df["y"].to_numpy()))

        # An array of colors passed under each accepted keyword
        for keyword in ("c", "fc", "facecolor", "facecolors"):
            ax = plt.figure().subplots()
            self.func(x=long_df["y"], **{keyword: expected})
            _draw_figure(ax.figure)
            assert_array_equal(ax.collections[0].get_facecolors(), expected)

        # Raw values passed as `c` together with a colormap
        ax = plt.figure().subplots()
        self.func(x=long_df["y"], c=long_df["y"], cmap=cmap)
        _draw_figure(ax.figure)
        assert_array_equal(ax.collections[0].get_facecolors(), expected)

    def test_unfilled_marker(self, long_df):
        """Check colors for an unfilled marker, with UserWarning as an error."""
        with warnings.catch_warnings():
            # Any UserWarning raised while plotting fails the test
            warnings.simplefilter("error", UserWarning)
            ax = self.func(long_df, x="y", y="a", marker="x", color="r")
            for collection in ax.collections:
                face = collection.get_facecolors().squeeze()
                edge = collection.get_edgecolors().squeeze()
                assert same_color(face, "r")
                assert same_color(edge, "r")

    @pytest.mark.parametrize(
        "orient,data_type",
        itertools.product(["h", "v", "y", "x"], ["dataframe", "dict"]),
    )
    def test_wide(self, wide_df, orient, data_type):
        """Check point positions and colors for wide-form input."""
        if data_type == "dict":
            wide_df = {name: col.to_numpy() for name, col in wide_df.items()}

        ax = self.func(data=wide_df, orient=orient, color="C0")
        _draw_figure(ax.figure)

        # "v" and "x" put the categories on the x axis (index 0)
        cat_idx = int(orient not in "vx")
        val_idx = 1 - cat_idx

        cat_axis = (ax.xaxis, ax.yaxis)[cat_idx]

        for i, tick_label in enumerate(cat_axis.get_majorticklabels()):

            column = tick_label.get_text()
            points = ax.collections[i]
            offsets = points.get_offsets().T

            # Points sit around category position i with the column's values
            assert_array_equal(offsets[cat_idx].round(), i)
            assert_array_equal(offsets[val_idx], wide_df[column])

            for facecolor in points.get_facecolors():
                assert tuple(facecolor) == to_rgba("C0")

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_flat(self, flat_series, orient):
        """Check point positions when the data is a single flat series."""
        ax = self.func(data=flat_series, orient=orient)
        _draw_figure(ax.figure)

        # "v" puts the category on x (index 0), "h" on y (index 1)
        cat_idx = {"v": 0, "h": 1}[orient]
        val_idx = 1 - cat_idx

        offsets = ax.collections[0].get_offsets().T

        # All points round to category position 0
        expected_cat = np.zeros(len(flat_series))
        assert_array_equal(offsets[cat_idx].round(), expected_cat)
        assert_array_equal(offsets[val_idx], flat_series)

    @pytest.mark.parametrize(
        "variables,orient",
        [
            # Order matters for assigning to x/y
            ({"cat": "a", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "a", "hue": None}, None),
            ({"cat": "a", "val": "y", "hue": "a"}, None),
            ({"val": "y", "cat": "a", "hue": "a"}, None),
            ({"cat": "a", "val": "y", "hue": "b"}, None),
            ({"val": "y", "cat": "a", "hue": "x"}, None),
            ({"cat": "s", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "s", "hue": None}, "h"),
            ({"cat": "a", "val": "b", "hue": None}, None),
            ({"val": "a", "cat": "b", "hue": None}, "h"),
            ({"cat": "a", "val": "t", "hue": None}, None),
            ({"val": "t", "cat": "a", "hue": None}, None),
            ({"cat": "d", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "d", "hue": None}, None),
            ({"cat": "a_cat", "val": "y", "hue": None}, None),
            ({"val": "y", "cat": "s_cat", "hue": None}, None),
        ],
    )
    def test_positions(self, long_df, variables, orient):
        """Each category level gets its own collection at the matching integer position."""
        cat_var, val_var, hue_var = (variables[k] for k in ("cat", "val", "hue"))
        # Dict insertion order decides which variable is passed as x and which as y
        var_names = list(variables.values())
        x_var, y_var = var_names[:2]

        ax = self.func(
            data=long_df, x=x_var, y=y_var, hue=hue_var, orient=orient,
        )
        _draw_figure(ax.figure)

        cat_idx = var_names.index(cat_var)
        val_idx = var_names.index(val_var)

        axes = (ax.xaxis, ax.yaxis)
        cat_axis, val_axis = axes[cat_idx], axes[val_idx]

        cat_data = long_df[cat_var]

        for i, level in enumerate(categorical_order(cat_data)):

            expected_vals = long_df.loc[cat_data == level, val_var]

            xy = ax.collections[i].get_offsets().T
            cat_pos, val_pos = xy[cat_idx], xy[val_idx]

            assert_array_equal(val_pos, val_axis.convert_units(expected_vals))
            assert_array_equal(cat_pos.round(), i)
            assert 0 <= np.ptp(cat_pos) <= .8

            # Stringify the level through a pandas Index before comparing to tick text
            expected_text = pd.Index([level]).astype(str)[0]
            assert cat_axis.get_majorticklabels()[i].get_text() == expected_text

    @pytest.mark.parametrize(
        "variables",
        [
            # Order matters for assigning to x/y
            {"cat": "a", "val": "y", "hue": "b"},
            {"val": "y", "cat": "a", "hue": "c"},
            {"cat": "a", "val": "y", "hue": "f"},
        ],
    )
    def test_positions_dodged(self, long_df, variables):
        """With dodge=True every category/hue pair gets its own collection and slot."""
        cat_var, val_var, hue_var = (variables[k] for k in ("cat", "val", "hue"))
        # Dict insertion order decides which variable is passed as x and which as y
        var_names = list(variables.values())
        x_var, y_var = var_names[:2]

        ax = self.func(
            data=long_df, x=x_var, y=y_var, hue=hue_var, dodge=True,
        )

        cat_levels = categorical_order(long_df[cat_var])
        hue_levels = categorical_order(long_df[hue_var])
        n_hue = len(hue_levels)

        # Expected slot centers relative to the category position: n_hue equal
        # slots that together span .8 and are centered on zero
        nest_width = .8 / n_hue
        offsets = np.linspace(0, .8, n_hue + 1)[:-1]
        offsets -= offsets.mean()

        cat_idx = var_names.index(cat_var)
        val_idx = var_names.index(val_var)

        pairs = itertools.product(enumerate(cat_levels), enumerate(hue_levels))
        for (i, cat_level), (j, hue_level) in pairs:
            mask = (long_df[cat_var] == cat_level) & (long_df[hue_var] == hue_level)
            expected = long_df.loc[mask, val_var]

            # Collections are indexed category-major, hue-minor
            xy = ax.collections[n_hue * i + j].get_offsets().T
            cat_pos, val_pos = xy[cat_idx], xy[val_idx]

            if pd.api.types.is_datetime64_any_dtype(expected):
                expected = mpl.dates.date2num(expected)

            assert_array_equal(val_pos, expected)

            assert_array_equal(cat_pos.round(), i)
            # NOTE(review): the rounding happens before the division, so this only
            # requires positions within half a unit of the slot center; dividing
            # first may have been intended -- confirm before tightening.
            assert_array_equal(
                (cat_pos - (i + offsets[j])).round() / nest_width, 0
            )
            assert 0 <= np.ptp(cat_pos) <= nest_width

    @pytest.mark.parametrize("cat_var", ["a", "s", "d"])
    def test_positions_unfixed(self, long_df, cat_var):
        """With native_scale=True points sit at the converted value of their level."""
        long_df = long_df.sort_values(cat_var)

        kws = {"size": .001}
        # self.func may be a partial, which has no __name__, so inspect its string form
        if "stripplot" in str(self.func):
            kws["jitter"] = False

        ax = self.func(data=long_df, x=cat_var, y="y", native_scale=True, **kws)

        for i, (level, group) in enumerate(long_df.groupby(cat_var)):

            xy = ax.collections[i].get_offsets().T
            cat_pos, val_pos = xy[0], xy[1]

            assert_array_equal(val_pos, group["y"])

            expected_pos = np.squeeze(ax.xaxis.convert_units(level)).item()
            assert_array_equal(cat_pos.round(), expected_pos)

    @pytest.mark.parametrize(
        "x_type,order",
        [
            (str, None),
            (str, ["a", "b", "c"]),
            (str, ["c", "a"]),
            (str, ["a", "b", "c", "d"]),
            (int, None),
            (int, [3, 1, 2]),
            (int, [3, 1]),
            (int, [1, 2, 3, 4]),
            (int, ["3", "1", "2"]),
        ]
    )
    def test_order(self, x_type, order):
        """The ``order`` argument controls collection order, tick labels and empty slots."""
        x = ["b", "a", "c"] if x_type is str else [2, 1, 3]
        y = [1, 2, 3]

        ax = self.func(x=x, y=y, order=order)
        _draw_figure(ax.figure)

        if order is None:
            # Expected default: strings in order of appearance, integers sorted
            order = np.sort(x) if x_type is int else x

        n_levels = len(order)
        assert len(ax.collections) == n_levels
        tick_labels = ax.xaxis.get_majorticklabels()

        assert ax.get_xlim()[1] == (n_levels - .5)

        for i, coll in enumerate(ax.collections):
            level = order[i]
            assert tick_labels[i].get_text() == str(level)

            xy = coll.get_offsets()
            key = x_type(level)
            if key in x:
                assert xy[0, 1] == y[x.index(key)]
            else:
                # A level listed in ``order`` but absent from the data stays empty
                assert not xy.size

    @pytest.mark.parametrize("hue_var", ["a", "b"])
    def test_hue_categorical(self, long_df, hue_var):
        """Without dodging, each point is colored by its hue level from the named palette."""
        cat_var = "b"
        pal_name = "muted"

        hue_levels = categorical_order(long_df[hue_var])
        cat_levels = categorical_order(long_df[cat_var])
        palette = dict(zip(hue_levels, color_palette(pal_name)))

        ax = self.func(data=long_df, x=cat_var, y="y", hue=hue_var, palette=pal_name)

        for i, level in enumerate(cat_levels):

            expected_hues = long_df.loc[long_df[cat_var] == level, hue_var]
            actual_colors = ax.collections[i].get_facecolors()

            assert len(expected_hues) == len(actual_colors)

            for hue, rgba in zip(expected_hues, actual_colors):
                assert tuple(rgba) == to_rgba(palette[hue])

    @pytest.mark.parametrize("hue_var", ["a", "b"])
    def test_hue_dodged(self, long_df, hue_var):
        """With dodge=True the non-empty collections cycle through the default palette."""
        ax = self.func(data=long_df, x="y", y="a", hue=hue_var, dodge=True)
        remaining = color_palette(n_colors=long_df[hue_var].nunique())
        collections = iter(ax.collections)

        # Slightly awkward logic to handle how the artists work: there can be
        # empty scatter collections, and the facecolors of an empty collection
        # are the default scatter color, so only collections with points count.
        while remaining:
            coll = next(collections)
            if not coll.get_offsets().any():
                continue
            actual = tuple(coll.get_facecolors()[0])
            assert actual == to_rgba(remaining.pop(0))

    @pytest.mark.parametrize(
        "val_var,val_col,hue_col",
        list(itertools.product(["x", "y"], ["b", "y", "t"], [None, "a"])),
    )
    def test_single(self, long_df, val_var, val_col, hue_col):
        """A single value variable is drawn as one unlabeled group at position 0."""
        ax = self.func(data=long_df, **{val_var: val_col, "hue": hue_col})
        _draw_figure(ax.figure)

        axis_vars = ["x", "y"]
        val_idx = axis_vars.index(val_var)
        cat_idx = int(not val_idx)
        cat_var = axis_vars[cat_idx]

        cat_axis = getattr(ax, f"{cat_var}axis")
        val_axis = getattr(ax, f"{val_var}axis")

        coll = ax.collections[0]
        xy = coll.get_offsets().T
        cat_pos, val_pos = xy[cat_idx], xy[val_idx]

        # All points stay within +/- .4 of the single categorical position
        assert_array_equal(cat_pos.round(), 0)
        assert cat_pos.max() <= .4
        assert cat_pos.min() >= -.4

        assert_array_equal(val_pos, val_axis.convert_units(long_df[val_col]))

        if hue_col is not None:
            hue_levels = categorical_order(long_df[hue_col])
            palette = dict(zip(hue_levels, color_palette()))

        for i, rgba in enumerate(coll.get_facecolors()):
            if hue_col is None:
                expected = "C0"
            else:
                expected = palette[long_df.loc[i, hue_col]]
            assert tuple(rgba) == to_rgba(expected)

        ticklabels = cat_axis.get_majorticklabels()
        assert len(ticklabels) == 1
        assert not ticklabels[0].get_text()

    def test_attributes(self, long_df):
        """size, linewidth and edgecolor are forwarded to the scatter collection."""
        size, linewidth, edgecolor = 2, 1, "C2"

        ax = self.func(
            x=long_df["y"], size=size, linewidth=linewidth, edgecolor=edgecolor,
        )
        (points,) = ax.collections

        # The collection reports the squared value of the ``size`` argument
        assert points.get_sizes().item() == size ** 2
        assert points.get_linewidths().item() == linewidth
        assert tuple(points.get_edgecolors().squeeze()) == to_rgba(edgecolor)

    def test_three_points(self):
        """Three points are each drawn in the default color "C0"."""
        ax = self.func(x=np.arange(3))
        facecolors = ax.collections[0].get_facecolor()
        for rgba in facecolors:
            assert tuple(rgba) == to_rgba("C0")

    def test_legend_categorical(self, long_df):
        """Legend entries for a categorical hue follow the categorical order of its levels."""
        ax = self.func(data=long_df, x="y", y="a", hue="b")
        labels = [text.get_text() for text in ax.legend_.texts]
        assert labels == categorical_order(long_df["b"])

    def test_legend_numeric(self, long_df):
        """Legend entries for a numeric hue are evenly spaced (first three checked)."""
        ax = self.func(data=long_df, x="y", y="a", hue="z")
        levels = [float(text.get_text()) for text in ax.legend_.texts]
        first_step = levels[1] - levels[0]
        second_step = levels[2] - levels[1]
        assert first_step == approx(second_step)

    def test_legend_attributes(self, long_df):
        """Legend handles pick up the edgecolor and linewidth passed to the plot."""
        edgecolor, linewidth = "r", 1
        ax = self.func(
            data=long_df, x="x", y="y", hue="a",
            edgecolor=edgecolor, linewidth=linewidth,
        )
        for handle in get_legend_handles(ax.get_legend()):
            assert same_color(handle.get_markeredgecolor(), edgecolor)
            assert handle.get_markeredgewidth() == linewidth

    def test_legend_disabled(self, long_df):
        """legend=False leaves the axes without a legend even when hue is mapped."""
        axes = self.func(data=long_df, x="y", y="a", hue="b", legend=False)
        legend = axes.legend_
        assert legend is None

    def test_palette_from_color_deprecation(self, long_df):
        """hue plus color (no palette) warns and falls back to a "dark:<color>" palette."""
        color = (.9, .4, .5)
        hue_var = "a"

        n_hue = long_df[hue_var].nunique()
        palette = color_palette(f"dark:{mpl.colors.to_hex(color)}", n_hue)

        with pytest.warns(FutureWarning, match="Setting a gradient palette"):
            ax = self.func(data=long_df, x="z", hue=hue_var, color=color)

        for rgba in ax.collections[0].get_facecolors():
            assert to_rgb(rgba) in palette

    def test_palette_with_hue_deprecation(self, long_df):
        """palette without hue warns and colors each strip from that palette in order."""
        palette = "Blues"
        with pytest.warns(FutureWarning, match="Passing `palette` without"):
            ax = self.func(data=long_df, x="a", y=long_df["y"], palette=palette)
        strips = ax.collections
        expected_colors = color_palette(palette, len(strips))
        for strip, expected in zip(strips, expected_colors):
            first_face = strip.get_facecolor()[0]
            assert same_color(first_face, expected)

    def test_log_scale(self):
        """Plotting onto log-scaled axes keeps data values and bounds the spread."""
        x = [1, 10, 100, 1000]

        # Single variable on a log-scaled x axis: offsets carry the raw values
        ax = plt.figure().subplots()
        ax.set_xscale("log")
        self.func(x=x)
        plotted = ax.collections[0].get_offsets()[:, 0]
        assert_array_equal(x, plotted)

        y = [1, 2, 3, 4]

        # native_scale on the log axis: one collection per distinct x value
        ax = plt.figure().subplots()
        ax.set_xscale("log")
        self.func(x=x, y=y, native_scale=True)
        for i, coll in enumerate(ax.collections):
            assert coll.get_offsets()[0, 0] == approx(x[i])

        x = y = np.ones(100)

        # Horizontal plot with a log-scaled categorical axis: the spread of the
        # points, measured in decades, stays within .8
        ax = plt.figure().subplots()
        ax.set_yscale("log")
        self.func(x=x, y=y, orient="h", native_scale=True)
        cat_points = ax.collections[0].get_offsets().copy()[:, 1]
        spread = np.ptp(np.log10(cat_points))
        assert spread <= .8

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="long", x="x", color="C3"),
            dict(data="long", y="y", hue="a", jitter=False),
            dict(data="long", x="a", y="y", hue="z", edgecolor="w", linewidth=.5),
            dict(data="long", x="a", y="y", hue="z", edgecolor="auto", linewidth=.5),
            dict(data="long", x="a_cat", y="y", hue="z"),
            dict(data="long", x="y", y="s", hue="c", orient="h", dodge=True),
            dict(data="long", x="s", y="y", hue="c", native_scale=True),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, kwargs):
        """The axes-level function and catplot(kind=...) draw identical plots."""
        # Work on a copy so the parametrized dict itself is never modified
        kwargs = kwargs.copy()
        fixtures = {"long": long_df, "wide": wide_df}
        data_key = kwargs["data"]
        if data_key in fixtures:
            kwargs["data"] = fixtures[data_key]

        # Strip the trailing "plot" from the function name to get the catplot kind;
        # a partial has no __name__, so fall back to the function it wraps
        try:
            name = self.func.__name__[:-4]
        except AttributeError:
            name = self.func.func.__name__[:-4]
        if name == "swarm":
            kwargs.pop("jitter", None)

        # Seed before each call so both plots draw the same jitter
        np.random.seed(0)  # for jitter
        ax = self.func(**kwargs)

        np.random.seed(0)
        grid = catplot(**kwargs, kind=name)

        assert_plots_equal(ax, grid.ax)

    def test_empty_palette(self):
        """Smoke test: all-empty x, y, hue and palette must not raise."""
        self.func(x=[], y=[], hue=[], palette=[])


class SharedAggTests(SharedAxesLevelTests):
    """Shared tests that exercise ``self.func``, which this class does not define."""

    def test_labels_flat(self):
        """A named series labels the axes with its index/series names and index ticks."""
        index = pd.Index(["a", "b", "c"], name="x")
        series = pd.Series([1, 2, 3], index, name="y")

        ax = self.func(series)

        # To populate texts; only needed on older matplotlibs
        _draw_figure(ax.figure)

        assert ax.get_xlabel() == index.name
        assert ax.get_ylabel() == series.name
        tick_texts = [tick.get_text() for tick in ax.get_xticklabels()]
        for text, level in zip(tick_texts, index):
            assert text == level


class SharedPatchArtistTests:
    """Legend checks shared by plots whose ``self.func`` draws patch artists."""

    @pytest.mark.parametrize("fill", [True, False])
    def test_legend_fill(self, long_df, fill):
        """Legend handles mirror the fill setting: palette on faces or on edges."""
        palette = color_palette()
        ax = self.func(
            long_df, x="x", y="y", hue="a",
            saturation=1, linecolor="k", fill=fill,
        )
        handles = get_legend_handles(ax.get_legend())
        for i, handle in enumerate(handles):
            face = handle.get_facecolor()
            edge = handle.get_edgecolor()
            if not fill:
                # Unfilled: fully transparent face, hue color moves to the edge
                assert face == (0, 0, 0, 0)
                assert same_color(edge, palette[i])
            else:
                assert same_color(face, palette[i])
                assert same_color(edge, "k")

    def test_legend_attributes(self, long_df):
        """Legend handles inherit the linewidth passed to the plot."""
        linewidth = 3
        ax = self.func(long_df, x="x", y="y", hue="a", linewidth=linewidth)
        for handle in get_legend_handles(ax.get_legend()):
            assert handle.get_linewidth() == linewidth


class TestStripPlot(SharedScatterTests):
    """Run the shared scatter tests against stripplot, plus jitter-specific checks."""

    func = staticmethod(stripplot)

    def test_jitter_unfixed(self, long_df):
        """Under native_scale, doubling the category spacing widens the jitter."""
        ax1, ax2 = plt.figure().subplots(2)
        kws = dict(data=long_df, x="y", orient="h", native_scale=True)

        # Same seed for both calls so only the category spacing differs
        np.random.seed(0)
        stripplot(**kws, y="s", ax=ax1)

        np.random.seed(0)
        stripplot(**kws, y=long_df["s"] * 2, ax=ax2)

        # Column 1 of the offsets holds the y coordinates, which are the jittered
        # categorical positions because orient="h". (Indexing with [1] alone
        # would pick out the second point's (x, y) pair instead.)
        p1 = ax1.collections[0].get_offsets()[:, 1]
        p2 = ax2.collections[0].get_offsets()[:, 1]

        assert p2.std() > p1.std()

    @pytest.mark.parametrize(
        "orient,jitter",
        itertools.product(["v", "h"], [True, .1]),
    )
    def test_jitter(self, long_df, orient, jitter):
        """Jitter moves points along the categorical axis only, within its range."""
        cat_var, val_var = "a", "y"
        # The parametrization uses "v"/"h"; comparing against "x" here would
        # send both cases down the horizontal branch.
        if orient == "v":
            x_var, y_var = cat_var, val_var
            cat_idx, val_idx = 0, 1
        else:
            x_var, y_var = val_var, cat_var
            cat_idx, val_idx = 1, 0

        cat_vals = categorical_order(long_df[cat_var])

        ax = stripplot(
            data=long_df, x=x_var, y=y_var, jitter=jitter,
        )

        # Upper bound on the spread of the categorical coordinate within a level
        if jitter is True:
            jitter_range = .4
        else:
            jitter_range = 2 * jitter

        for i, level in enumerate(cat_vals):

            vals = long_df.loc[long_df[cat_var] == level, val_var]
            points = ax.collections[i].get_offsets().T
            cat_points = points[cat_idx]
            val_points = points[val_idx]

            assert_array_equal(val_points, vals)
            assert np.std(cat_points) > 0
            assert np.ptp(cat_points) <= jitter_range


class TestSwarmPlot(SharedScatterTests):
    """Run the shared scatter tests against swarmplot."""

    # Every call made through self.func gets warn_thresh=1 bound in.
    # NOTE(review): presumably this keeps swarmplot's overlap warning from
    # firing during the shared tests -- confirm against swarmplot's signature.
    func = staticmethod(partial(swarmplot, warn_thresh=1))


class TestBoxPlot(SharedAxesLevelTests, SharedPatchArtistTests):
    """Tests for boxplot: box/whisker geometry, colors and property pass-through."""

    func = staticmethod(boxplot)

    @pytest.fixture
    def common_kws(self):
        """Fixture providing ``saturation=1``; it is consumed outside this class."""
        return {"saturation": 1}

    def get_last_color(self, ax):
        """Return the single RGBA face color shared by the boxes of the last container."""
        colors = [b.get_facecolor() for b in ax.containers[-1].boxes]
        unique_colors = np.unique(colors, axis=0)
        assert len(unique_colors) == 1
        return to_rgba(unique_colors.squeeze())

    def get_box_verts(self, box):
        """Return the MOVETO/LINETO vertices of a box's path, transposed to (x, y) rows."""
        path = box.get_path()
        visible_codes = [mpl.path.Path.MOVETO, mpl.path.Path.LINETO]
        visible = np.isin(path.codes, visible_codes)
        return path.vertices[visible].T

    def check_box(self, bxp, data, orient, pos, width=0.8):
        """Assert that the box and median of ``bxp`` match the quartiles of ``data``.

        The box must span the 25th-75th percentiles along the value axis and
        ``pos +/- width / 2`` along the position axis; the median line must sit
        at the 50th percentile and span the same width.
        """
        pos_idx, val_idx = self.orient_indices(orient)

        p25, p50, p75 = np.percentile(data, [25, 50, 75])

        box = self.get_box_verts(bxp.box)
        assert box[val_idx].min() == approx(p25, 1e-3)
        assert box[val_idx].max() == approx(p75, 1e-3)
        assert box[pos_idx].min() == approx(pos - width / 2)
        assert box[pos_idx].max() == approx(pos + width / 2)

        med = bxp.median.get_xydata().T
        assert np.allclose(med[val_idx], (p50, p50), rtol=1e-3)
        assert np.allclose(med[pos_idx], (pos - width / 2, pos + width / 2))

    def check_whiskers(self, bxp, data, orient, pos, capsize=0.4, whis=1.5):
        """Assert that whiskers, caps and fliers of ``bxp`` match ``data``.

        Whiskers must reach from the quartiles to the most extreme data points
        within ``whis`` IQRs of the box, caps must sit at those points with
        width ``capsize``, and fliers must be exactly the points beyond them.
        """
        pos_idx, val_idx = self.orient_indices(orient)

        whis_lo = bxp.whiskers[0].get_xydata().T
        whis_hi = bxp.whiskers[1].get_xydata().T
        caps_lo = bxp.caps[0].get_xydata().T
        caps_hi = bxp.caps[1].get_xydata().T
        fliers = bxp.fliers.get_xydata().T

        p25, p75 = np.percentile(data, [25, 75])
        iqr = p75 - p25

        # Most extreme observations that still fall inside the whisker reach
        adj_lo = data[data >= (p25 - iqr * whis)].min()
        adj_hi = data[data <= (p75 + iqr * whis)].max()

        assert whis_lo[val_idx].max() == approx(p25, 1e-3)
        assert whis_lo[val_idx].min() == approx(adj_lo)
        assert np.allclose(whis_lo[pos_idx], (pos, pos))
        assert np.allclose(caps_lo[val_idx], (adj_lo, adj_lo))
        assert np.allclose(caps_lo[pos_idx], (pos - capsize / 2, pos + capsize / 2))

        assert whis_hi[val_idx].min() == approx(p75, 1e-3)
        assert whis_hi[val_idx].max() == approx(adj_hi)
        assert np.allclose(whis_hi[pos_idx], (pos, pos))
        assert np.allclose(caps_hi[val_idx], (adj_hi, adj_hi))
        assert np.allclose(caps_hi[pos_idx], (pos - capsize / 2, pos + capsize / 2))

        flier_data = data[(data < adj_lo) | (data > adj_hi)]
        assert sorted(fliers[val_idx]) == sorted(flier_data)
        assert np.allclose(fliers[pos_idx], pos)

    @pytest.mark.parametrize("orient,col", [("x", "y"), ("y", "z")])
    def test_single_var(self, long_df, orient, col):
        """A single value variable is drawn as one box at position 0."""
        # The value variable goes on the axis opposite to the orientation
        var = {"x": "y", "y": "x"}[orient]
        ax = boxplot(long_df, **{var: col})
        bxp = ax.containers[0][0]
        self.check_box(bxp, long_df[col], orient, 0)
        self.check_whiskers(bxp, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient,col", [(None, "x"), ("x", "y"), ("y", "z")])
    def test_vector_data(self, long_df, orient, col):
        """A bare vector is drawn as one box; orient=None is checked as "x"."""
        ax = boxplot(long_df[col], orient=orient)
        orient = "x" if orient is None else orient
        bxp = ax.containers[0][0]
        self.check_box(bxp, long_df[col], orient, 0)
        self.check_whiskers(bxp, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_wide_data(self, wide_df, orient):
        """Every column of wide-form data gets a correctly placed box."""
        orient = {"h": "y", "v": "x"}[orient]
        ax = boxplot(wide_df, orient=orient, color="C0")
        # A single color produces a single container holding all boxes, so index
        # the boxes within it; iterating the containers would only ever check
        # the first column.
        bxp, = ax.containers
        for i, col in enumerate(wide_df.columns):
            self.check_box(bxp[i], wide_df[col], orient, i)
            self.check_whiskers(bxp[i], wide_df[col], orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_grouped(self, long_df, orient):
        """Grouping by a categorical variable gives one box per level, in order."""
        value = {"x": "y", "y": "x"}[orient]
        ax = boxplot(long_df, **{orient: "a", value: "z"})
        bxp, = ax.containers
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_box(bxp[i], data, orient, i)
            self.check_whiskers(bxp[i], data, orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_hue_grouped(self, long_df, orient):
        """A hue variable dodges the boxes: one container per hue level."""
        value = {"x": "y", "y": "x"}[orient]
        ax = boxplot(long_df, hue="c", **{orient: "a", value: "z"})
        for i, hue_level in enumerate(categorical_order(long_df["c"])):
            bxp = ax.containers[i]
            for j, level in enumerate(categorical_order(long_df["a"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                # The two-element offset list assumes "c" has exactly two levels
                pos = j + [-.2, +.2][i]
                width, capsize = 0.4, 0.2
                self.check_box(bxp[j], data, orient, pos, width)
                self.check_whiskers(bxp[j], data, orient, pos, capsize)

    def test_hue_not_dodged(self, long_df):
        """A hue that is constant within each category leaves the boxes undodged."""
        levels = categorical_order(long_df["b"])
        # Boolean hue: True for the first two levels of "b", False for the rest
        hue = long_df["b"].isin(levels[:2])
        ax = boxplot(long_df, x="b", y="z", hue=hue)
        bxps = ax.containers
        for i, level in enumerate(levels):
            idx = int(i < 2)
            data = long_df.loc[long_df["b"] == level, "z"]
            self.check_box(bxps[idx][i % 2], data, "x", i)
            self.check_whiskers(bxps[idx][i % 2], data, "x", i)

    def test_dodge_native_scale(self, long_df):
        """With native_scale, dodge widths scale with the smallest category spacing."""
        centers = categorical_order(long_df["s"])
        hue_levels = categorical_order(long_df["c"])
        spacing = min(np.diff(centers))
        width = 0.8 * spacing / len(hue_levels)
        offset = width / len(hue_levels)
        ax = boxplot(long_df, x="s", y="z", hue="c", native_scale=True)
        for i, hue_level in enumerate(hue_levels):
            bxp = ax.containers[i]
            for j, center in enumerate(centers):
                rows = (long_df["s"] == center) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                # The two-element offset list assumes "c" has exactly two levels
                pos = center + [-offset, +offset][i]
                self.check_box(bxp[j], data, "x", pos, width)
                self.check_whiskers(bxp[j], data, "x", pos, width / 2)

    def test_dodge_native_scale_log(self, long_df):
        """On a log-scaled axis all dodged boxes have the same width in log units."""
        pos = 10 ** long_df["s"]
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        boxplot(long_df, x=pos, y="z", hue="c", native_scale=True, ax=ax)
        widths = []
        for bxp in ax.containers:
            for box in bxp.boxes:
                coords = np.log10(box.get_path().vertices.T[0])
                widths.append(np.ptp(coords))
        assert np.std(widths) == approx(0)

    def test_dodge_without_hue(self, long_df):
        """dodge=True without a hue variable leaves the boxes at full width."""
        ax = boxplot(long_df, x="a", y="y", dodge=True)
        bxp, = ax.containers
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "y"]
            self.check_box(bxp[i], data, "x", i)
            self.check_whiskers(bxp[i], data, "x", i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_log_data_scale(self, long_df, orient):
        """Box statistics are correct when the value axis is log-scaled."""
        var = {"x": "y", "y": "x"}[orient]
        s = long_df["z"]
        ax = mpl.figure.Figure().subplots()
        getattr(ax, f"set_{var}scale")("log")
        boxplot(**{var: s}, whis=np.inf, ax=ax)
        bxp = ax.containers[0][0]
        self.check_box(bxp, s, orient, 0)
        self.check_whiskers(bxp, s, orient, 0, whis=np.inf)

    def test_color(self, long_df):
        """``color`` sets the face color of every box."""
        color = "#123456"
        ax = boxplot(long_df, x="a", y="y", color=color, saturation=1)
        for box in ax.containers[0].boxes:
            assert same_color(box.get_facecolor(), color)

    def test_wide_data_multicolored(self, wide_df):
        """Wide data without ``color`` gets one container per column."""
        ax = boxplot(wide_df)
        assert len(ax.containers) == wide_df.shape[1]

    def test_wide_data_single_color(self, wide_df):
        """Wide data with ``color`` gets a single container in that color."""
        ax = boxplot(wide_df, color="C1", saturation=1)
        assert len(ax.containers) == 1
        for box in ax.containers[0].boxes:
            assert same_color(box.get_facecolor(), "C1")

    def test_hue_colors(self, long_df):
        """Each hue level's container uses the matching color-cycle entry."""
        ax = boxplot(long_df, x="a", y="y", hue="b", saturation=1)
        for i, bxp in enumerate(ax.containers):
            for box in bxp.boxes:
                assert same_color(box.get_facecolor(), f"C{i}")

    def test_linecolor(self, long_df):
        """``linecolor`` applies to medians, whiskers, caps, box edges and fliers."""
        color = "#778815"
        ax = boxplot(long_df, x="a", y="y", linecolor=color)
        bxp = ax.containers[0]
        for line in [*bxp.medians, *bxp.whiskers, *bxp.caps]:
            assert same_color(line.get_color(), color)
        for box in bxp.boxes:
            assert same_color(box.get_edgecolor(), color)
        for flier in bxp.fliers:
            assert same_color(flier.get_markeredgecolor(), color)

    def test_linecolor_gray_warning(self, long_df):
        """linecolor="gray" raises the FutureWarning pointing to "auto"."""
        with pytest.warns(FutureWarning, match="Use \"auto\" to set automatic"):
            boxplot(long_df, x="y", linecolor="gray")

    def test_saturation(self, long_df):
        """``saturation`` desaturates the requested face color."""
        color = "#8912b0"
        ax = boxplot(long_df["x"], color=color, saturation=.5)
        box = ax.containers[0].boxes[0]
        assert np.allclose(box.get_facecolor()[:3], desaturate(color, 0.5))

    def test_linewidth(self, long_df):
        """``linewidth`` applies to boxes, medians, whiskers and caps."""
        width = 5
        ax = boxplot(long_df, x="a", y="y", linewidth=width)
        bxp = ax.containers[0]
        for line in [*bxp.boxes, *bxp.medians, *bxp.whiskers, *bxp.caps]:
            assert line.get_linewidth() == width

    def test_fill(self, long_df):
        """fill=False draws boxes as Line2D artists in the requested color."""
        color = "#459900"
        ax = boxplot(x=long_df["z"], fill=False, color=color)
        bxp = ax.containers[0]
        assert isinstance(bxp.boxes[0], mpl.lines.Line2D)
        for line in [*bxp.boxes, *bxp.medians, *bxp.whiskers, *bxp.caps]:
            assert same_color(line.get_color(), color)

    @pytest.mark.parametrize("notch_param", ["notch", "shownotches"])
    def test_notch(self, long_df, notch_param):
        """Both spellings of the notch parameter produce a 12-vertex box path."""
        ax = boxplot(x=long_df["z"], **{notch_param: True})
        verts = ax.containers[0].boxes[0].get_path().vertices
        assert len(verts) == 12

    def test_whis(self, long_df):
        """``whis`` changes the whisker reach."""
        data = long_df["z"]
        ax = boxplot(x=data, whis=2)
        bxp = ax.containers[0][0]
        self.check_whiskers(bxp, data, "y", 0, whis=2)

    def test_gap(self, long_df):
        """``gap`` shrinks each dodged box by that fraction of its slot width."""
        ax = boxplot(long_df, x="a", y="z", hue="c", gap=.1)
        for i, hue_level in enumerate(categorical_order(long_df["c"])):
            bxp = ax.containers[i]
            for j, level in enumerate(categorical_order(long_df["a"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                # The two-element offset list assumes "c" has exactly two levels
                pos = j + [-.2, +.2][i]
                width = 0.9 * 0.4
                self.check_box(bxp[j], data, "x", pos, width)

    def test_prop_dicts(self, long_df):
        """The ``*props`` dictionaries are applied to the matching artists."""
        prop_dicts = dict(
            boxprops=dict(linewidth=3),
            medianprops=dict(color=".1"),
            whiskerprops=dict(linestyle="--"),
            capprops=dict(solid_capstyle="butt"),
            flierprops=dict(marker="s"),
        )
        # Container attribute names that are not simply the element name + "s"
        attr_map = dict(box="boxes", flier="fliers")
        ax = boxplot(long_df, x="a", y="z", hue="c", **prop_dicts)
        for bxp in ax.containers:
            for element in ["box", "median", "whisker", "cap", "flier"]:
                attr = attr_map.get(element, f"{element}s")
                for artist in getattr(bxp, attr):
                    for k, v in prop_dicts[f"{element}props"].items():
                        assert plt.getp(artist, k) == v

    def test_showfliers(self, long_df):
        """showfliers=False leaves the container without flier artists."""
        ax = boxplot(long_df["x"], showfliers=False)
        assert not ax.containers[0].fliers

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s"),
            dict(data="null", x="a", y="y", hue="a"),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="null", x="a", y="y", hue="b", fill=False, gap=.2),
            dict(data="null", x="a", y="y", whis=1, showfliers=False),
            dict(data="null", x="a", y="y", linecolor="r", linewidth=5),
            dict(data="null", x="a", y="y", shownotches=True, showcaps=False),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """boxplot and catplot(kind="box") draw identical plots."""
        # Copy first: the parametrized dict is a shared object, and replacing
        # its string placeholders in place would leak into any later use of it.
        kwargs = kwargs.copy()

        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df
        elif kwargs["data"] == "flat":
            kwargs["data"] = flat_series
        elif kwargs["data"] == "null":
            kwargs["data"] = null_df
        elif kwargs["data"] is None:
            # No data object: pass the named columns as vectors instead
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]

        ax = boxplot(**kwargs)
        g = catplot(**kwargs, kind="box")

        assert_plots_equal(ax, g.ax)


class TestBoxenPlot(SharedAxesLevelTests, SharedPatchArtistTests):
    """Tests for the axes-level `boxenplot` function."""

    func = staticmethod(boxenplot)

    @pytest.fixture
    def common_kws(self):
        """Keyword arguments (full saturation) exposed as the `common_kws` fixture."""
        return {"saturation": 1}

    def get_last_color(self, ax):
        """Return the RGBA face color of the middle box in `ax.collections[-2]`."""

        fcs = ax.collections[-2].get_facecolors()
        return to_rgba(fcs[len(fcs) // 2])

    def get_box_width(self, path, orient="x"):
        """Return the peak-to-peak extent of `path` along one vertex coordinate.

        `orient="x"` measures the second coordinate (index 1) and `orient="y"`
        measures the first (index 0).
        """
        # NOTE(review): the lookup below selects the coordinate *opposite* to the
        # one named by `orient`. Every caller in this class plots with `x=` only
        # and relies on the default, so confirm intent before changing it.

        verts = path.vertices.T
        idx = ["y", "x"].index(orient)
        return np.ptp(verts[idx])

    def check_boxen(self, patches, data, orient, pos, width=0.8):
        """Assert that one boxen `PatchCollection` matches `data` at `pos`.

        Checks that all vertices lie within `pos ± width / 2` on the position
        axis (rounded to 4 decimals), that the 25th and 75th percentiles of
        `data` appear among the value-axis vertices, and that, on the value
        axis, vertex 0 of every path after the first equals vertex 2 of the
        preceding path.
        """

        pos_idx, val_idx = self.orient_indices(orient)
        # Resulting shape is (coordinate, path, vertex)
        verts = np.stack([v.vertices for v in patches.get_paths()], 1).T

        assert verts[pos_idx].min().round(4) >= np.round(pos - width / 2, 4)
        assert verts[pos_idx].max().round(4) <= np.round(pos + width / 2, 4)
        assert np.isin(
            np.percentile(data, [25, 75]).round(4), verts[val_idx].round(4).flat
        ).all()
        assert_array_equal(verts[val_idx, 1:, 0], verts[val_idx, :-1, 2])

    @pytest.mark.parametrize("orient,col", [("x", "y"), ("y", "z")])
    def test_single_var(self, long_df, orient, col):
        """Assigning only the value variable draws one boxen group at position 0."""

        var = {"x": "y", "y": "x"}[orient]
        ax = boxenplot(long_df, **{var: col})
        patches = ax.collections[0]
        self.check_boxen(patches, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient,col", [(None, "x"), ("x", "y"), ("y", "z")])
    def test_vector_data(self, long_df, orient, col):
        """A bare vector passed as data draws one boxen group at position 0."""

        orient = "x" if orient is None else orient
        ax = boxenplot(long_df[col], orient=orient)
        patches = ax.collections[0]
        self.check_boxen(patches, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_wide_data(self, wide_df, orient):
        """Wide-form data draws one boxen group per column at the column index."""

        orient = {"h": "y", "v": "x"}[orient]
        ax = boxenplot(wide_df, orient=orient)
        collections = ax.findobj(mpl.collections.PatchCollection)
        for i, patches in enumerate(collections):
            col = wide_df.columns[i]
            self.check_boxen(patches, wide_df[col], orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_grouped(self, long_df, orient):
        """Each level of the grouping variable gets a boxen group at its index."""

        value = {"x": "y", "y": "x"}[orient]
        ax = boxenplot(long_df, **{orient: "a", value: "z"})
        levels = categorical_order(long_df["a"])
        collections = ax.findobj(mpl.collections.PatchCollection)
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_boxen(collections[i], data, orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_hue_grouped(self, long_df, orient):
        """With `hue="c"`, groups are checked at offsets of ±0.2 with width 0.4."""

        value = {"x": "y", "y": "x"}[orient]
        ax = boxenplot(long_df, hue="c", **{orient: "a", value: "z"})
        collections = iter(ax.findobj(mpl.collections.PatchCollection))
        for i, level in enumerate(categorical_order(long_df["a"])):
            for j, hue_level in enumerate(categorical_order(long_df["c"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = i + [-.2, +.2][j]
                width = 0.4
                self.check_boxen(next(collections), data, orient, pos, width)

    def test_dodge_native_scale(self, long_df):
        """Dodged groups on a native numeric axis are sized from the center spacing."""

        centers = categorical_order(long_df["s"])
        hue_levels = categorical_order(long_df["c"])
        # Expected geometry: 80% of the smallest gap between centers, split
        # evenly between the hue levels
        spacing = min(np.diff(centers))
        width = 0.8 * spacing / len(hue_levels)
        offset = width / len(hue_levels)
        ax = boxenplot(long_df, x="s", y="z", hue="c", native_scale=True)
        collections = iter(ax.findobj(mpl.collections.PatchCollection))
        for center in centers:
            for i, hue_level in enumerate(hue_levels):
                rows = (long_df["s"] == center) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = center + [-offset, +offset][i]
                self.check_boxen(next(collections), data, "x", pos, width)

    def test_color(self, long_df):
        """The middle box of every collection has the requested `color`."""

        color = "#123456"
        ax = boxenplot(long_df, x="a", y="y", color=color, saturation=1)
        collections = ax.findobj(mpl.collections.PatchCollection)
        for patches in collections:
            fcs = patches.get_facecolors()
            assert same_color(fcs[len(fcs) // 2], color)

    def test_hue_colors(self, long_df):
        """Middle-box colors cycle through `C0`, `C1`, ... over the hue levels."""

        ax = boxenplot(long_df, x="a", y="y", hue="b", saturation=1)
        n_levels = long_df["b"].nunique()
        collections = ax.findobj(mpl.collections.PatchCollection)
        for i, patches in enumerate(collections):
            fcs = patches.get_facecolors()
            assert same_color(fcs[len(fcs) // 2], f"C{i % n_levels}")

    def test_linecolor(self, long_df):
        """`linecolor` sets the edge color of every patch collection."""

        color = "#669913"
        ax = boxenplot(long_df, x="a", y="y", linecolor=color)
        for patches in ax.findobj(mpl.collections.PatchCollection):
            assert same_color(patches.get_edgecolor(), color)

    def test_linewidth(self, long_df):
        """`linewidth` sets the line width of every patch collection."""

        width = 5
        ax = boxenplot(long_df, x="a", y="y", linewidth=width)
        for patches in ax.findobj(mpl.collections.PatchCollection):
            assert patches.get_linewidth() == width

    def test_saturation(self, long_df):
        """The middle box color equals `desaturate(color, saturation)`."""

        color = "#8912b0"
        ax = boxenplot(long_df["x"], color=color, saturation=.5)
        fcs = ax.collections[0].get_facecolors()
        assert np.allclose(fcs[len(fcs) // 2, :3], desaturate(color, 0.5))

    def test_gap(self, long_df):
        """`gap=.2` scales the first box of each group to 0.8 of its ungapped width."""

        ax1, ax2 = mpl.figure.Figure().subplots(2)
        boxenplot(long_df, x="a", y="y", hue="s", ax=ax1)
        boxenplot(long_df, x="a", y="y", hue="s", gap=.2, ax=ax2)
        c1 = ax1.findobj(mpl.collections.PatchCollection)
        c2 = ax2.findobj(mpl.collections.PatchCollection)
        for p1, p2 in zip(c1, c2):
            w1 = np.ptp(p1.get_paths()[0].vertices[:, 0])
            w2 = np.ptp(p2.get_paths()[0].vertices[:, 0])
            assert (w2 / w1) == pytest.approx(0.8)

    def test_fill(self, long_df):
        """`fill=False` leaves every patch collection without face colors."""

        ax = boxenplot(long_df, x="a", y="y", hue="s", fill=False)
        for c in ax.findobj(mpl.collections.PatchCollection):
            assert not c.get_facecolors().size

    def test_k_depth_int(self, rng):
        """An integer `k_depth` of `k` draws `2k - 1` boxes."""

        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x, k_depth=(k := 8))
        assert len(ax.collections[0].get_paths()) == (k * 2 - 1)

    def test_k_depth_full(self, rng):
        """`k_depth="full"` spans the data range and leaves no flier offsets."""

        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x=x, k_depth="full")
        paths = ax.collections[0].get_paths()
        assert len(paths) == 2 * int(np.log2(x.size)) + 1
        verts = np.concatenate([p.vertices for p in paths]).T
        assert verts[0].min() == x.min()
        assert verts[0].max() == x.max()
        assert not ax.collections[1].get_offsets().size

    def test_trust_alpha(self, rng):
        """`trust_alpha=.1` draws more boxes than `trust_alpha=.001`."""

        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x, k_depth="trustworthy", trust_alpha=.1)
        boxenplot(x, k_depth="trustworthy", trust_alpha=.001, ax=ax)
        cs = ax.findobj(mpl.collections.PatchCollection)
        assert len(cs[0].get_paths()) > len(cs[1].get_paths())

    def test_outlier_prop(self, rng):
        """`outlier_prop=.001` draws more boxes than `outlier_prop=.1`."""

        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x, k_depth="proportion", outlier_prop=.001)
        boxenplot(x, k_depth="proportion", outlier_prop=.1, ax=ax)
        cs = ax.findobj(mpl.collections.PatchCollection)
        assert len(cs[0].get_paths()) > len(cs[1].get_paths())

    def test_exponential_width_method(self, rng):
        """With the exponential method, the first three box widths share one ratio."""

        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x=x, width_method="exponential")
        c = ax.findobj(mpl.collections.PatchCollection)[0]
        ws = [self.get_box_width(p) for p in c.get_paths()]
        assert (ws[1] / ws[0]) == pytest.approx(ws[2] / ws[1])

    def test_linear_width_method(self, rng):
        """With the linear method, the first three box widths share one difference."""

        x = rng.normal(0, 1, 10_000)
        ax = boxenplot(x=x, width_method="linear")
        c = ax.findobj(mpl.collections.PatchCollection)[0]
        ws = [self.get_box_width(p) for p in c.get_paths()]
        assert (ws[1] - ws[0]) == pytest.approx(ws[2] - ws[1])

    def test_area_width_method(self, rng):
        """With the area method on uniform data, every box is wider than 0.7."""

        x = rng.uniform(0, 1, 10_000)
        ax = boxenplot(x=x, width_method="area", k_depth=2)
        ps = ax.findobj(mpl.collections.PatchCollection)[0].get_paths()
        ws = [self.get_box_width(p) for p in ps]
        assert np.greater(ws, 0.7).all()

    def test_box_kws(self, long_df):
        """`box_kws` entries are applied to every patch collection."""

        ax = boxenplot(long_df, x="a", y="y", box_kws={"linewidth": (lw := 7.1)})
        for c in ax.findobj(mpl.collections.PatchCollection):
            assert c.get_linewidths() == lw

    def test_line_kws(self, long_df):
        """`line_kws` entries are applied to every line on the axes."""

        ax = boxenplot(long_df, x="a", y="y", line_kws={"linewidth": (lw := 6.2)})
        for line in ax.lines:
            assert line.get_linewidth() == lw

    def test_flier_kws(self, long_df):
        """`flier_kws` entries are applied to every `PathCollection`."""

        ax = boxenplot(long_df, x="a", y="y", flier_kws={"marker": (marker := "X")})
        expected = mpl.markers.MarkerStyle(marker).get_path().vertices
        for c in ax.findobj(mpl.collections.PathCollection):
            assert_array_equal(c.get_paths()[0].vertices, expected)

    def test_k_depth_checks(self, long_df):
        """Invalid `k_depth` values raise `ValueError` (string) or `TypeError` (tuple)."""

        with pytest.raises(ValueError, match="The value for `k_depth`"):
            boxenplot(x=long_df["y"], k_depth="auto")

        with pytest.raises(TypeError, match="The `k_depth` parameter"):
            boxenplot(x=long_df["y"], k_depth=(1, 2))

    def test_width_method_check(self, long_df):
        """An unknown `width_method` raises `ValueError`."""

        with pytest.raises(ValueError, match="The value for `width_method`"):
            boxenplot(x=long_df["y"], width_method="uniform")

    def test_scale_deprecation(self, long_df):
        """Passing the `scale` parameter emits a `FutureWarning`."""

        with pytest.warns(FutureWarning, match="The `scale` parameter has been"):
            boxenplot(x=long_df["y"], scale="linear")

        with pytest.warns(FutureWarning, match=".+result for 'area' will appear"):
            boxenplot(x=long_df["y"], scale="area")

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s", showfliers=False),
            dict(data="null", x="a", y="y", hue="a", saturation=.5),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="null", x="a", y="y", hue="b", fill=False, gap=.2),
            dict(data="null", x="a", y="y", linecolor="r", linewidth=5),
            dict(data="long", x="a", y="y", k_depth="trustworthy", trust_alpha=.1),
            dict(data="long", x="a", y="y", k_depth="proportion", outlier_prop=.1),
            dict(data="long", x="a", y="z", width_method="area"),
            dict(data="long", x="a", y="z", box_kws={"alpha": .2}, alpha=.4)
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """`boxenplot` and `catplot(kind="boxen")` draw equal plots.

        The string stored under `"data"` names the fixture to plot. When it is
        `None`, the `x` / `y` / `hue` entries are replaced by the corresponding
        `long_df` columns so that vectors are passed directly.
        """

        # The parametrized dicts are created once, when the decorator runs, and
        # the same objects are handed to the test on every invocation. Copy
        # before substituting so a repeated run of the same case still sees the
        # original string placeholders rather than already-substituted data.
        kwargs = dict(kwargs)

        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df
        elif kwargs["data"] == "flat":
            kwargs["data"] = flat_series
        elif kwargs["data"] == "null":
            kwargs["data"] = null_df
        elif kwargs["data"] is None:
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]

        ax = boxenplot(**kwargs)
        g = catplot(**kwargs, kind="boxen")

        assert_plots_equal(ax, g.ax)


class TestViolinPlot(SharedAxesLevelTests, SharedPatchArtistTests):
    """Tests for the axes-level `violinplot` function."""

    func = staticmethod(violinplot)

    @pytest.fixture
    def common_kws(self):
        """Keyword arguments (full saturation) exposed as the `common_kws` fixture."""
        return {"saturation": 1}

    def get_last_color(self, ax):
        """Return the RGBA face color of the last collection on `ax`."""

        color = ax.collections[-1].get_facecolor()
        return to_rgba(color)

    def violin_width(self, poly, orient="x"):
        """Return the peak-to-peak extent of a violin along its position axis."""

        idx, _ = self.orient_indices(orient)
        return np.ptp(poly.get_paths()[0].vertices[:, idx])

    def check_violin(self, poly, data, orient, pos, width=0.8):
        """Assert that one violin polygon matches `data` at position `pos`.

        The polygon must stay within `pos ± width / 2` on the position axis and
        must span exactly the range of `data` on the value axis.
        """

        pos_idx, val_idx = self.orient_indices(orient)
        verts = poly.get_paths()[0].vertices.T

        assert verts[pos_idx].min() >= (pos - width / 2)
        assert verts[pos_idx].max() <= (pos + width / 2)
        # Assumes violin was computed with cut=0
        assert verts[val_idx].min() == approx(data.min())
        assert verts[val_idx].max() == approx(data.max())

    @pytest.mark.parametrize("orient,col", [("x", "y"), ("y", "z")])
    def test_single_var(self, long_df, orient, col):
        """Assigning only the value variable draws one violin at position 0."""

        var = {"x": "y", "y": "x"}[orient]
        ax = violinplot(long_df, **{var: col}, cut=0)
        poly = ax.collections[0]
        self.check_violin(poly, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient,col", [(None, "x"), ("x", "y"), ("y", "z")])
    def test_vector_data(self, long_df, orient, col):
        """A bare vector passed as data draws one violin at position 0."""

        orient = "x" if orient is None else orient
        ax = violinplot(long_df[col], cut=0, orient=orient)
        poly = ax.collections[0]
        self.check_violin(poly, long_df[col], orient, 0)

    @pytest.mark.parametrize("orient", ["h", "v"])
    def test_wide_data(self, wide_df, orient):
        """Wide-form data draws one violin per column at the column index."""

        orient = {"h": "y", "v": "x"}[orient]
        ax = violinplot(wide_df, cut=0, orient=orient)
        for i, poly in enumerate(ax.collections):
            col = wide_df.columns[i]
            self.check_violin(poly, wide_df[col], orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_grouped(self, long_df, orient):
        """Each level of the grouping variable gets a violin at its index."""

        value = {"x": "y", "y": "x"}[orient]
        ax = violinplot(long_df, **{orient: "a", value: "z"}, cut=0)
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_violin(ax.collections[i], data, orient, i)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_hue_grouped(self, long_df, orient):
        """With `hue="c"`, violins are checked at offsets of ±0.2 with width 0.4."""

        value = {"x": "y", "y": "x"}[orient]
        ax = violinplot(long_df, hue="c", **{orient: "a", value: "z"}, cut=0)
        polys = iter(ax.collections)
        for i, level in enumerate(categorical_order(long_df["a"])):
            for j, hue_level in enumerate(categorical_order(long_df["c"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = i + [-.2, +.2][j]
                width = 0.4
                self.check_violin(next(polys), data, orient, pos, width)

    def test_hue_not_dodged(self, long_df):
        """A hue vector derived from the `x` levels leaves violins at full width."""

        levels = categorical_order(long_df["b"])
        hue = long_df["b"].isin(levels[:2])
        ax = violinplot(long_df, x="b", y="z", hue=hue, cut=0)
        for i, level in enumerate(levels):
            poly = ax.collections[i]
            data = long_df.loc[long_df["b"] == level, "z"]
            self.check_violin(poly, data, "x", i)

    def test_dodge_native_scale(self, long_df):
        """Dodged violins on a native numeric axis are sized from the center spacing."""

        centers = categorical_order(long_df["s"])
        hue_levels = categorical_order(long_df["c"])
        # Expected geometry: 80% of the smallest gap between centers, split
        # evenly between the hue levels
        spacing = min(np.diff(centers))
        width = 0.8 * spacing / len(hue_levels)
        offset = width / len(hue_levels)
        ax = violinplot(long_df, x="s", y="z", hue="c", native_scale=True, cut=0)
        violins = iter(ax.collections)
        for center in centers:
            for i, hue_level in enumerate(hue_levels):
                rows = (long_df["s"] == center) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = center + [-offset, +offset][i]
                poly = next(violins)
                self.check_violin(poly, data, "x", pos, width)

    def test_dodge_native_scale_log(self, long_df):
        """On a log-scaled axis, width-normalized violins have equal log10 widths."""

        pos = 10 ** long_df["s"]
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        variables = dict(x=pos, y="z", hue="c")
        violinplot(long_df, **variables, native_scale=True, density_norm="width", ax=ax)
        widths = []
        n_violins = long_df["s"].nunique() * long_df["c"].nunique()
        for poly in ax.collections[:n_violins]:
            verts = poly.get_paths()[0].vertices[:, 0]
            coords = np.log10(verts)
            widths.append(np.ptp(coords))
        assert np.std(widths) == approx(0)

    def test_color(self, long_df):
        """Every collection takes the requested `color` as its face color."""

        color = "#123456"
        ax = violinplot(long_df, x="a", y="y", color=color, saturation=1)
        for poly in ax.collections:
            assert same_color(poly.get_facecolor(), color)

    def test_hue_colors(self, long_df):
        """Face colors cycle through `C0`, `C1`, ... over the hue levels."""

        ax = violinplot(long_df, x="a", y="y", hue="b", saturation=1)
        n_levels = long_df["b"].nunique()
        for i, poly in enumerate(ax.collections):
            assert same_color(poly.get_facecolor(), f"C{i % n_levels}")

    @pytest.mark.parametrize("inner", ["box", "quart", "stick", "point"])
    def test_linecolor(self, long_df, inner):
        """`linecolor` colors polygon edges, line collections and lines alike."""

        color = "#669913"
        ax = violinplot(long_df, x="a", y="y", linecolor=color, inner=inner)
        for poly in ax.findobj(mpl.collections.PolyCollection):
            assert same_color(poly.get_edgecolor(), color)
        for lines in ax.findobj(mpl.collections.LineCollection):
            assert same_color(lines.get_color(), color)
        for line in ax.lines:
            assert same_color(line.get_color(), color)

    def test_linewidth(self, long_df):
        """`linewidth` sets the line width of the first violin polygon."""

        width = 5
        ax = violinplot(long_df, x="a", y="y", linewidth=width)
        poly = ax.collections[0]
        assert poly.get_linewidth() == width

    def test_saturation(self, long_df):
        """The violin face color equals `desaturate(color, saturation)`."""

        color = "#8912b0"
        ax = violinplot(long_df["x"], color=color, saturation=.5)
        poly = ax.collections[0]
        assert np.allclose(poly.get_facecolors()[0, :3], desaturate(color, 0.5))

    @pytest.mark.parametrize("inner", ["box", "quart", "stick", "point"])
    def test_fill(self, long_df, inner):
        """`fill=False` removes face colors and draws every artist in `color`."""

        color = "#459900"
        ax = violinplot(x=long_df["z"], fill=False, color=color, inner=inner)
        for poly in ax.findobj(mpl.collections.PolyCollection):
            assert poly.get_facecolor().size == 0
            assert same_color(poly.get_edgecolor(), color)
        for lines in ax.findobj(mpl.collections.LineCollection):
            assert same_color(lines.get_color(), color)
        for line in ax.lines:
            assert same_color(line.get_color(), color)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_box(self, long_df, orient):
        """The default inner artists match matplotlib's `boxplot_stats`.

        The first three lines on the axes are checked, in order, against the
        whisker ends, the quartiles and the median, all at position 0.
        """

        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient)
        stats = mpl.cbook.boxplot_stats(long_df["y"])[0]

        whiskers = ax.lines[0].get_xydata()
        assert whiskers[0, val_idx] == stats["whislo"]
        assert whiskers[1, val_idx] == stats["whishi"]
        assert whiskers[:, pos_idx].tolist() == [0, 0]

        box = ax.lines[1].get_xydata()
        assert box[0, val_idx] == stats["q1"]
        assert box[1, val_idx] == stats["q3"]
        assert box[:, pos_idx].tolist() == [0, 0]

        median = ax.lines[2].get_xydata()
        assert median[0, val_idx] == stats["med"]
        assert median[0, pos_idx] == 0

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_quartiles(self, long_df, orient):
        """`inner="quart"` draws symmetric lines at the 25/50/75th percentiles."""

        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient, inner="quart")
        quartiles = np.percentile(long_df["y"], [25, 50, 75])

        for q, line in zip(quartiles, ax.lines):
            pts = line.get_xydata()
            for pt in pts:
                assert pt[val_idx] == q
            assert pts[0, pos_idx] == -pts[1, pos_idx]

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_stick(self, long_df, orient):
        """`inner="stick"` draws one symmetric segment per observation, in order."""

        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient, inner="stick")
        for i, pts in enumerate(ax.collections[1].get_segments()):
            for pt in pts:
                assert pt[val_idx] == long_df["y"].iloc[i]
            assert pts[0, pos_idx] == -pts[1, pos_idx]

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_inner_points(self, long_df, orient):
        """`inner="points"` draws one point per observation at position 0."""

        pos_idx, val_idx = self.orient_indices(orient)
        ax = violinplot(long_df["y"], orient=orient, inner="points")
        points = ax.collections[1]
        for i, pt in enumerate(points.get_offsets()):
            assert pt[val_idx] == long_df["y"].iloc[i]
            assert pt[pos_idx] == 0

    def test_split_single(self, long_df):
        """`split=True` without hue gives each violin a flat edge at `i + .4`."""

        ax = violinplot(long_df, x="a", y="z", split=True, cut=0)
        levels = categorical_order(long_df["a"])
        for i, level in enumerate(levels):
            data = long_df.loc[long_df["a"] == level, "z"]
            self.check_violin(ax.collections[i], data, "x", i)
            verts = ax.collections[i].get_paths()[0].vertices
            # At least 100 vertices must sit on the flat edge
            assert np.isclose(verts[:, 0], i + .4).sum() >= 100

    def test_split_multi(self, long_df):
        """`split=True` with hue gives each half violin a flat edge at the center."""

        ax = violinplot(long_df, x="a", y="z", hue="c", split=True, cut=0)
        polys = iter(ax.collections)
        for i, level in enumerate(categorical_order(long_df["a"])):
            for j, hue_level in enumerate(categorical_order(long_df["c"])):
                rows = (long_df["a"] == level) & (long_df["c"] == hue_level)
                data = long_df.loc[rows, "z"]
                pos = i + [-.2, +.2][j]
                poly = next(polys)
                self.check_violin(poly, data, "x", pos, width=0.4)
                verts = poly.get_paths()[0].vertices
                # At least 100 vertices must sit on the flat edge
                assert np.isclose(verts[:, 0], i).sum() >= 100

    def test_density_norm_area(self, long_df):
        """By default, data stretched by a factor of 5 gives a violin 1/5 as wide."""

        y = long_df["y"].to_numpy()
        ax = violinplot([y, y * 5], color="C0")
        widths = []
        for poly in ax.collections:
            widths.append(self.violin_width(poly))
        assert widths[0] / widths[1] == approx(5)

    def test_density_norm_count(self, long_df):
        """With `density_norm="count"`, tripling the observations triples the width."""

        y = long_df["y"].to_numpy()
        ax = violinplot([np.repeat(y, 3), y], density_norm="count", color="C0")
        widths = []
        for poly in ax.collections:
            widths.append(self.violin_width(poly))
        assert widths[0] / widths[1] == approx(3)

    def test_density_norm_width(self, long_df):
        """With `density_norm="width"`, every violin is 0.8 wide."""

        ax = violinplot(long_df, x="a", y="y", density_norm="width")
        for poly in ax.collections:
            assert self.violin_width(poly) == approx(0.8)

    def test_common_norm(self, long_df):
        """With `common_norm=True`, exactly one violin reaches the full 0.4 width."""

        ax = violinplot(long_df, x="a", y="y", hue="c", common_norm=True)
        widths = []
        for poly in ax.collections:
            widths.append(self.violin_width(poly))
        assert sum(w > 0.3999 for w in widths) == 1

    def test_scale_deprecation(self, long_df):
        """Passing `scale` emits a `FutureWarning` pointing to `density_norm`."""

        with pytest.warns(FutureWarning, match=r".+Pass `density_norm='count'`"):
            violinplot(long_df, x="a", y="y", hue="b", scale="count")

    def test_scale_hue_deprecation(self, long_df):
        """Passing `scale_hue` emits a `FutureWarning` pointing to `common_norm`."""

        with pytest.warns(FutureWarning, match=r".+Pass `common_norm=True`"):
            violinplot(long_df, x="a", y="y", hue="b", scale_hue=False)

    def test_bw_adjust(self, long_df):
        """A smaller `bw_adjust` gives a less smooth density outline."""

        ax = violinplot(long_df["y"], bw_adjust=.2)
        # Draw on the same axes explicitly rather than relying on pyplot's
        # current-axes state, so both violins end up in `ax.collections`.
        violinplot(long_df["y"], bw_adjust=2, ax=ax)
        kde1 = ax.collections[0].get_paths()[0].vertices[:100, 0]
        kde2 = ax.collections[1].get_paths()[0].vertices[:100, 0]
        assert np.std(np.diff(kde1)) > np.std(np.diff(kde2))

    def test_bw_deprecation(self, long_df):
        """Passing `bw` emits a `FutureWarning` pointing to `bw_method`."""

        with pytest.warns(FutureWarning, match=r".*Setting `bw_method='silverman'`"):
            violinplot(long_df["y"], bw="silverman")

    def test_gap(self, long_df):
        """`gap=.2` separates two dodged violins by 20% of their dodged width."""

        ax = violinplot(long_df, y="y", hue="c", gap=.2)
        a = ax.collections[0].get_paths()[0].vertices[:, 0].max()
        b = ax.collections[1].get_paths()[0].vertices[:, 0].min()
        assert (b - a) == approx(0.2 * 0.8 / 2)

    def test_inner_kws(self, long_df):
        """`inner_kws` entries are applied to the inner stick artists."""

        kws = {"linewidth": 3}
        ax = violinplot(long_df, x="a", y="y", inner="stick", inner_kws=kws)
        # Sticks are added as LineCollection artists (see `test_inner_stick`),
        # so checking only `ax.lines` would pass without asserting anything.
        stick_collections = ax.findobj(mpl.collections.LineCollection)
        assert stick_collections
        for sticks in stick_collections:
            assert np.allclose(sticks.get_linewidths(), kws["linewidth"])
        for line in ax.lines:
            assert line.get_linewidth() == kws["linewidth"]

    def test_box_inner_kws(self, long_df):
        """Box-specific `inner_kws` reach the whisker, box and median artists.

        Lines are read in repeating groups of three: whisker, box, median.
        """

        kws = {"box_width": 10, "whis_width": 2, "marker": "x"}
        ax = violinplot(long_df, x="a", y="y", inner_kws=kws)
        for line in ax.lines[::3]:
            assert line.get_linewidth() == kws["whis_width"]
        for line in ax.lines[1::3]:
            assert line.get_linewidth() == kws["box_width"]
        for line in ax.lines[2::3]:
            assert line.get_marker() == kws["marker"]

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y", split=True),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s", split=True),
            dict(data="null", x="a", y="y", hue="a"),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="null", x="a", y="y", hue="b", fill=False, gap=.2),
            dict(data="null", x="a", y="y", linecolor="r", linewidth=5),
            dict(data="long", x="a", y="y", inner="stick"),
            dict(data="long", x="a", y="y", inner="points"),
            dict(data="long", x="a", y="y", hue="b", inner="quartiles", split=True),
            dict(data="long", x="a", y="y", density_norm="count", common_norm=True),
            dict(data="long", x="a", y="y", bw_adjust=2),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """`violinplot` and `catplot(kind="violin")` draw equal plots.

        The string stored under `"data"` names the fixture to plot. When it is
        `None`, the `x` / `y` / `hue` entries are replaced by the corresponding
        `long_df` columns so that vectors are passed directly.
        """

        # The parametrized dicts are created once, when the decorator runs, and
        # the same objects are handed to the test on every invocation. Copy
        # before substituting so a repeated run of the same case still sees the
        # original string placeholders rather than already-substituted data.
        kwargs = dict(kwargs)

        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df
        elif kwargs["data"] == "flat":
            kwargs["data"] = flat_series
        elif kwargs["data"] == "null":
            kwargs["data"] = null_df
        elif kwargs["data"] is None:
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]

        ax = violinplot(**kwargs)
        g = catplot(**kwargs, kind="violin")

        assert_plots_equal(ax, g.ax)


class TestBarPlot(SharedAggTests):

    func = staticmethod(barplot)

    @pytest.fixture
    def common_kws(self):
        """Keyword arguments (full saturation) exposed as the `common_kws` fixture."""
        return dict(saturation=1)

    def get_last_color(self, ax):
        """Return the one RGBA face color shared by all bars in the last container."""
        last_container = ax.containers[-1]
        distinct = np.unique(
            [bar.get_facecolor() for bar in last_container], axis=0
        )
        # Every bar in the container is expected to have the same color
        assert len(distinct) == 1
        return to_rgba(distinct.squeeze())

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_single_var(self, orient):
        """A single numeric vector yields one bar whose length is the vector mean."""
        values = pd.Series([1, 3, 10])
        ax = barplot(**{orient: values})
        (only_bar,) = ax.patches
        # Data assigned to `x` extends along the bar's width, `y` along its height
        getter_name = "get_width" if orient == "x" else "get_height"
        measured = getattr(only_bar, getter_name)()
        assert measured == approx(values.mean())

    @pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
    def test_wide_df(self, wide_df, orient):
        """Wide-form data draws one bar per column with length equal to its mean."""
        ax = barplot(wide_df, orient=orient)
        # Map the legacy "h"/"v" spellings onto "y"/"x"
        axis = {"h": "y", "v": "x"}.get(orient, orient)
        getter_name = {"x": "get_height", "y": "get_width"}[axis]
        for col_idx, bar in enumerate(ax.patches):
            column_mean = wide_df.iloc[:, col_idx].mean()
            assert getattr(bar, getter_name)() == approx(column_mean)

    @pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
    def test_vector_orient(self, orient):
        """A dict of scalars draws one bar per key along the requested orientation."""
        labels = ["a", "b", "c"]
        amounts = [1, 2, 3]
        data = {label: amount for label, amount in zip(labels, amounts)}
        # Map the legacy "h"/"v" spellings onto "y"/"x"
        orient = {"h": "y", "v": "x"}.get(orient, orient)
        length_getter = {"x": "get_height", "y": "get_width"}[orient]
        start_getter = f"get_{orient}"
        ax = barplot(data, orient=orient)
        for idx, bar in enumerate(ax.patches):
            # Bars are 0.8 wide, so each starts 0.4 before its integer position
            assert getattr(bar, start_getter)() == approx(idx - 0.4)
            assert getattr(bar, length_getter)() == approx(amounts[idx])

    def test_xy_vertical(self):
        """String `x` with numeric `y` draws vertical bars centered on their index."""
        categories = ["a", "b", "c"]
        heights = [1, 3, 2.5]

        ax = barplot(x=categories, y=heights)
        for idx, bar in enumerate(ax.patches):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(idx)
            assert bar.get_y() == approx(0)
            assert bar.get_height() == approx(heights[idx])
            assert bar.get_width() == approx(0.8)

    def test_xy_horizontal(self):
        """Numeric `x` with string `y` draws horizontal bars centered on their index."""
        lengths = [1, 3, 2.5]
        categories = ["a", "b", "c"]

        ax = barplot(x=lengths, y=categories)
        for idx, bar in enumerate(ax.patches):
            assert bar.get_x() == approx(0)
            center = bar.get_y() + bar.get_height() / 2
            assert center == approx(idx)
            assert bar.get_height() == approx(0.8)
            assert bar.get_width() == approx(lengths[idx])

    def test_xy_with_na_grouper(self):
        """A `None` entry in the grouping vector gets no tick and no bar."""
        groups = ["a", None, "b"]
        values = [1, 2, 3]
        ax = barplot(x=groups, y=values)
        _draw_figure(ax.figure)  # For matplotlib<3.5
        assert ax.get_xticks() == [0, 1]
        tick_texts = [label.get_text() for label in ax.get_xticklabels()]
        assert tick_texts == ["a", "b"]
        bars = ax.patches
        assert bars[0].get_height() == 1
        assert bars[1].get_height() == 3

    def test_xy_with_na_value(self):
        """A `None` value keeps its category tick; the first two bars are 1 and 3."""
        groups = ["a", "b", "c"]
        values = [1, None, 3]
        ax = barplot(x=groups, y=values)
        _draw_figure(ax.figure)  # For matplotlib<3.5
        assert ax.get_xticks() == [0, 1, 2]
        tick_texts = [label.get_text() for label in ax.get_xticklabels()]
        assert tick_texts == ["a", "b", "c"]
        bars = ax.patches
        assert bars[0].get_height() == 1
        assert bars[1].get_height() == 3

    def test_hue_redundant(self):
        """Using the `x` vector as `hue` too keeps full-width bars, colored per bar."""
        categories = ["a", "b", "c"]
        heights = [1, 2, 3]

        ax = barplot(x=categories, y=heights, hue=categories, saturation=1)
        for idx, bar in enumerate(ax.patches):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(idx)
            assert bar.get_y() == 0
            assert bar.get_height() == heights[idx]
            assert bar.get_width() == approx(0.8)
            assert same_color(bar.get_facecolor(), f"C{idx}")

    def test_hue_matched(self):
        """With one hue level per `x` level, bars stay full width and undodged."""
        categories = ["a", "b", "c"]
        heights = [1, 2, 3]
        hue_values = ["x", "x", "y"]

        ax = barplot(
            x=categories, y=heights, hue=hue_values, saturation=1, legend=False
        )
        for idx, bar in enumerate(ax.patches):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(idx)
            assert bar.get_y() == 0
            assert bar.get_height() == heights[idx]
            assert bar.get_width() == approx(0.8)
            # The first two bars share hue "x"; the third has hue "y"
            assert same_color(bar.get_facecolor(), f"C{idx // 2}")

    def test_hue_matched_by_name(self):
        """Naming the same column for `x` and `hue` keeps full-width bars."""
        data = {"x": ["a", "b", "c"], "y": [1, 2, 3]}
        heights = data["y"]
        ax = barplot(data, x="x", y="y", hue="x", saturation=1)
        for idx, bar in enumerate(ax.patches):
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(idx)
            assert bar.get_y() == 0
            assert bar.get_height() == heights[idx]
            assert bar.get_width() == approx(0.8)
            assert same_color(bar.get_facecolor(), f"C{idx}")

    def test_hue_dodged(self):
        """Two hue levels per category are dodged to ±0.2 with half-width bars."""
        categories = ["a", "b", "a", "b"]
        heights = [1, 2, 3, 4]
        hue_values = ["x", "x", "y", "y"]

        ax = barplot(
            x=categories, y=heights, hue=hue_values, saturation=1, legend=False
        )
        for idx, bar in enumerate(ax.patches):
            # The first two bars are checked against hue 0, the last two hue 1
            hue_idx, cat_idx = divmod(idx, 2)
            direction = 1 if hue_idx else -1
            center = bar.get_x() + bar.get_width() / 2
            assert center == approx(cat_idx + direction * 0.8 / 4)
            assert bar.get_y() == 0
            assert bar.get_height() == heights[idx]
            assert bar.get_width() == approx(0.8 / 2)
            assert same_color(bar.get_facecolor(), f"C{hue_idx}")

    def test_gap(self):
        """The gap parameter shrinks each dodged bar by the given fraction."""
        cats = ["a", "b", "a", "b"]
        heights = [1, 2, 3, 4]
        groups = ["x", "x", "y", "y"]
        ax = barplot(x=cats, y=heights, hue=groups, gap=.25, legend=False)

        for patch in ax.patches:
            assert patch.get_width() == approx(0.8 / 2 * .75)

    def test_hue_undodged(self):
        """With dodge=False, bars for each hue level share the category center."""
        cats = ["a", "b", "a", "b"]
        heights = [1, 2, 3, 4]
        groups = ["x", "x", "y", "y"]
        ax = barplot(
            x=cats, y=heights, hue=groups, saturation=1, dodge=False, legend=False,
        )

        for idx, patch in enumerate(ax.patches):
            hue_idx, cat_idx = divmod(idx, 2)
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(cat_idx)
            assert patch.get_y() == 0
            assert patch.get_height() == heights[idx]
            assert patch.get_width() == approx(0.8)
            assert same_color(patch.get_facecolor(), f"C{hue_idx}")

    def test_hue_order(self):
        """hue_order controls both color assignment and the order bars are drawn."""
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        reversed_levels = ["c", "b", "a"]
        ax = barplot(
            x=levels, y=heights, hue=levels, hue_order=reversed_levels, saturation=1,
        )

        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_x() + patch.get_width() / 2
            assert same_color(patch.get_facecolor(), f"C{idx}")
            assert midpoint == approx(2 - idx)

    def test_hue_norm(self):
        """Values outside hue_norm get the same color as the nearest limit."""
        values = [1, 2, 3, 4]
        ax = barplot(x=values, y=values, hue=values, hue_norm=(2, 3))

        shades = []
        for patch in ax.patches:
            shades.append(patch.get_facecolor())

        assert shades[0] == shades[1]
        assert shades[1] != shades[2]
        assert shades[2] == shades[3]

    def test_fill(self):
        """fill=False moves the hue color to the edge and clears the face."""
        cats = ["a", "b", "a", "b"]
        heights = [1, 2, 3, 4]
        groups = ["x", "x", "y", "y"]
        ax = barplot(x=cats, y=heights, hue=groups, fill=False, legend=False)

        transparent = (0, 0, 0, 0)
        for idx, patch in enumerate(ax.patches):
            assert same_color(patch.get_edgecolor(), f"C{idx // 2}")
            assert same_color(patch.get_facecolor(), transparent)

    def test_xy_native_scale(self):
        """native_scale=True centers bars on their numeric x values."""
        positions, heights = [2, 4, 8], [1, 2, 3]
        ax = barplot(x=positions, y=heights, native_scale=True)

        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(positions[idx])
            assert patch.get_y() == 0
            assert patch.get_height() == heights[idx]
            # Expected width is 0.8 times 2 (2 being the smallest gap in positions)
            assert patch.get_width() == approx(0.8 * 2)

    def test_xy_native_scale_log_transform(self):
        """On a log x axis, bars are centered in log space and widen with x."""
        positions, heights = [1, 10, 100], [1, 2, 3]

        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        barplot(x=positions, y=heights, native_scale=True, ax=ax)

        for idx, patch in enumerate(ax.patches):
            edges = [patch.get_x(), patch.get_x() + patch.get_width()]
            lo, hi = np.log10(edges)
            # Midpoint is taken in log10 space and mapped back to data units
            center = 10 ** (lo + (hi - lo) / 2)
            assert center == approx(positions[idx])
            assert patch.get_y() == 0
            assert patch.get_height() == heights[idx]

        first, second = ax.patches[0], ax.patches[1]
        assert second.get_width() > first.get_width()

    def test_datetime_native_scale_axis(self):
        """Datetime x data with native_scale=True yields a date-aware x axis."""
        dates = pd.date_range("2010-01-01", periods=20, freq="MS")
        values = np.arange(20)
        ax = barplot(x=dates, y=values, native_scale=True)

        locator = ax.xaxis.get_major_locator()
        assert "Date" in locator.__class__.__name__

        day = "2003-02-28"
        converted = ax.xaxis.convert_units([day])
        assert_array_equal(converted, mpl.dates.date2num([day]))

    def test_native_scale_dodged(self):
        """Dodged bars on a native scale meet at the shared x position."""
        positions, heights = [2, 4, 2, 4], [1, 2, 3, 4]
        groups = ["x", "x", "y", "y"]
        ax = barplot(x=positions, y=heights, hue=groups, native_scale=True)

        first_level, second_level = ax.patches[:2], ax.patches[2:]

        # First hue level: the right edge touches the position
        for pos, patch in zip(positions[:2], first_level):
            assert patch.get_x() + patch.get_width() == approx(pos)

        # Second hue level: the left edge touches the position
        for pos, patch in zip(positions[2:], second_level):
            assert patch.get_x() == approx(pos)

    def test_native_scale_log_transform_dodged(self):
        """Dodged bars on a log-scaled native axis meet at the shared x position."""
        positions, heights = [1, 100, 1, 100], [1, 2, 3, 4]
        groups = ["x", "x", "y", "y"]

        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        barplot(x=positions, y=heights, hue=groups, native_scale=True, ax=ax)

        first_level, second_level = ax.patches[:2], ax.patches[2:]

        # First hue level: the right edge touches the position
        for pos, patch in zip(positions[:2], first_level):
            assert patch.get_x() + patch.get_width() == approx(pos)

        # Second hue level: the left edge touches the position
        for pos, patch in zip(positions[2:], second_level):
            assert patch.get_x() == approx(pos)

    def test_estimate_default(self, long_df):
        """Without an estimator argument, bar heights are the group means."""
        group_col, value_col = "a", "y"
        expected = long_df.groupby(group_col)[value_col].mean()

        ax = barplot(long_df, x=group_col, y=value_col, errorbar=None)

        levels = categorical_order(long_df[group_col])
        for idx, patch in enumerate(ax.patches):
            level = levels[idx]
            assert patch.get_height() == approx(expected[level])

    def test_estimate_string(self, long_df):
        """A string estimator name selects the aggregation for bar heights."""
        group_col, value_col = "a", "y"
        expected = long_df.groupby(group_col)[value_col].median()

        ax = barplot(
            long_df, x=group_col, y=value_col, estimator="median", errorbar=None,
        )

        levels = categorical_order(long_df[group_col])
        for idx, patch in enumerate(ax.patches):
            level = levels[idx]
            assert patch.get_height() == approx(expected[level])

    def test_estimate_func(self, long_df):
        """A callable estimator is used as the aggregation for bar heights."""
        group_col, value_col = "a", "y"
        expected = long_df.groupby(group_col)[value_col].median()

        ax = barplot(
            long_df, x=group_col, y=value_col, estimator=np.median, errorbar=None,
        )

        levels = categorical_order(long_df[group_col])
        for idx, patch in enumerate(ax.patches):
            level = levels[idx]
            assert patch.get_height() == approx(expected[level])

    def test_weighted_estimate(self, long_df):
        """The weights variable produces a weighted average bar height."""
        ax = barplot(long_df, y="y", weights="x")

        first_bar = ax.patches[0]
        weighted_mean = np.average(long_df["y"], weights=long_df["x"])
        assert first_bar.get_height() == weighted_mean

    def test_estimate_log_transform(self, long_df):
        """On a log-scaled axis, the estimate is the mean of the log10 values."""
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        barplot(x=long_df["z"], ax=ax)

        # Unpacking also checks that exactly one bar was drawn
        [only_bar] = ax.patches
        log_mean = np.log10(long_df["z"]).mean()
        assert only_bar.get_width() == 10 ** log_mean

    def test_errorbars(self, long_df):
        """errorbar="sd" draws lines spanning one standard deviation around the mean."""
        group_col, value_col = "a", "y"
        stats = long_df.groupby(group_col)[value_col].agg(["mean", "std"])

        ax = barplot(long_df, x=group_col, y=value_col, errorbar="sd")

        levels = categorical_order(long_df[group_col])
        for idx, err_line in enumerate(ax.lines):
            center = stats.loc[levels[idx]]
            lower, upper = err_line.get_ydata()
            assert lower == approx(center["mean"] - center["std"])
            assert upper == approx(center["mean"] + center["std"])

    def test_width(self):
        """The width parameter sets the bar width while keeping bars centered."""
        width = .5
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        ax = barplot(x=levels, y=heights, width=width)

        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(idx)
            assert patch.get_width() == width

    def test_width_native_scale(self):
        """On a native scale, every bar is twice the requested width here."""
        width = .5
        positions, heights = [4, 6, 10], [1, 2, 3]
        ax = barplot(x=positions, y=heights, width=width, native_scale=True)

        # The factor 2 matches the smallest gap between positions
        expected = width * 2
        for patch in ax.patches:
            assert patch.get_width() == expected

    def test_width_spaced_categories(self):
        """Layering a plot that skips a category keeps the default bar width."""
        ax = barplot(x=["a", "b", "c"], y=[4, 5, 6])
        barplot(x=["a", "c"], y=[1, 3], ax=ax)

        widths = [patch.get_width() for patch in ax.patches]
        for bar_width in widths:
            assert bar_width == pytest.approx(0.8)

    def test_saturation_color(self):
        """The default saturation desaturates an explicitly requested color.

        The reference color must actually be passed to ``barplot``; otherwise
        the comparison is against an unrelated default color and proves nothing.
        """
        color = (.1, .9, .2)
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = barplot(x=x, y=y, color=color)
        for bar in ax.patches:
            # Lower variance across the RGB channels means a less saturated color
            assert np.var(bar.get_facecolor()[:3]) < np.var(color)

    def test_saturation_palette(self):
        """The default saturation desaturates every color drawn from a palette."""
        palette = color_palette("viridis", 3)
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        ax = barplot(x=levels, y=heights, hue=levels, palette=palette)

        for idx, patch in enumerate(ax.patches):
            drawn_rgb = patch.get_facecolor()[:3]
            # Lower variance across the RGB channels means a less saturated color
            assert np.var(drawn_rgb) < np.var(palette[idx])

    def test_legend_numeric_auto(self, long_df):
        """The default legend for a numeric hue shows at most six entries."""
        ax = barplot(long_df, x="x", y="y", hue="x")
        entries = ax.get_legend().texts
        assert len(entries) <= 6

    def test_legend_numeric_full(self, long_df):
        """legend="full" lists every unique numeric hue value in sorted order."""
        ax = barplot(long_df, x="x", y="y", hue="x", legend="full")

        shown = []
        for text in ax.get_legend().texts:
            shown.append(text.get_text())

        expected = list(map(str, sorted(long_df["x"].unique())))
        assert shown == expected

    def test_legend_disabled(self, long_df):
        """legend=False leaves the axes without a legend."""
        ax = barplot(long_df, x="x", y="y", hue="b", legend=False)
        legend = ax.get_legend()
        assert legend is None

    def test_error_caps(self):
        """Error-bar caps span the width of the matching bar."""
        levels = ["a", "b", "c"] * 2
        values = [1, 2, 3, 4, 5, 6]
        ax = barplot(x=levels, y=values, capsize=.8, errorbar="pi")

        assert len(ax.patches) == len(ax.lines)
        for patch, err_line in zip(ax.patches, ax.lines):
            xs = err_line.get_xdata()
            left, right = patch.get_x(), patch.get_x() + patch.get_width()
            assert len(xs) == 8
            assert np.nanmin(xs) == approx(left)
            assert np.nanmax(xs) == approx(right)

    def test_error_caps_native_scale(self):
        """Error-bar caps span the matching bar when x uses its native scale."""
        positions = [2, 4, 20] * 2
        values = [1, 2, 3, 4, 5, 6]
        ax = barplot(
            x=positions, y=values, capsize=.8, native_scale=True, errorbar="pi",
        )

        assert len(ax.patches) == len(ax.lines)
        for patch, err_line in zip(ax.patches, ax.lines):
            xs = err_line.get_xdata()
            left, right = patch.get_x(), patch.get_x() + patch.get_width()
            assert len(xs) == 8
            assert np.nanmin(xs) == approx(left)
            assert np.nanmax(xs) == approx(right)

    def test_error_caps_native_scale_log_transform(self):
        """Error-bar caps span the matching bar on a log-scaled native axis."""
        positions = [1, 10, 1000] * 2
        values = [1, 2, 3, 4, 5, 6]

        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        barplot(
            x=positions, y=values, capsize=.8, native_scale=True, errorbar="pi", ax=ax,
        )

        assert len(ax.patches) == len(ax.lines)
        for patch, err_line in zip(ax.patches, ax.lines):
            xs = err_line.get_xdata()
            left, right = patch.get_x(), patch.get_x() + patch.get_width()
            assert len(xs) == 8
            assert np.nanmin(xs) == approx(left)
            assert np.nanmax(xs) == approx(right)

    def test_bar_kwargs(self):
        """Extra keyword arguments are applied to every bar patch."""
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        kwargs = dict(linewidth=3, facecolor=(.5, .4, .3, .2), rasterized=True)
        ax = barplot(x=levels, y=heights, **kwargs)

        for patch in ax.patches:
            # Each keyword has a matching ``get_<name>`` accessor on the patch
            for prop, expected in kwargs.items():
                getter = getattr(patch, f"get_{prop}")
                assert getter() == expected

    def test_legend_attributes(self, long_df):
        """Legend handles carry the palette face color and the edge styling."""
        palette = color_palette()
        style = dict(edgecolor="k", linewidth=3)
        ax = barplot(long_df, x="a", y="y", hue="c", saturation=1, **style)

        handles = get_legend_handles(ax.get_legend())
        for idx, handle in enumerate(handles):
            assert same_color(handle.get_facecolor(), palette[idx])
            assert same_color(handle.get_edgecolor(), "k")
            assert handle.get_linewidth() == 3

    def test_legend_unfilled(self, long_df):
        """With fill=False, legend handles are transparent with colored edges."""
        palette = color_palette()
        ax = barplot(long_df, x="a", y="y", hue="c", fill=False, linewidth=3)

        handles = get_legend_handles(ax.get_legend())
        for idx, handle in enumerate(handles):
            assert handle.get_facecolor() == (0, 0, 0, 0)
            assert same_color(handle.get_edgecolor(), palette[idx])
            assert handle.get_linewidth() == 3

    @pytest.mark.parametrize("fill", [True, False])
    def test_err_kws(self, fill):
        """err_kws properties are applied to the error-bar lines."""
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        err_kws = dict(color=(1, 1, .5, .5), linewidth=5)
        ax = barplot(x=levels, y=heights, fill=fill, err_kws=err_kws)

        for err_line in ax.lines:
            # Each keyword has a matching ``get_<name>`` accessor on the line
            for prop, expected in err_kws.items():
                getter = getattr(err_line, f"get_{prop}")
                assert getter() == expected

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s"),
            dict(data="long", x="a", y="y", units="c"),
            dict(data="null", x="a", y="y", hue="a", gap=.1, fill=False),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="long", x="a", y="y", errorbar=("pi", 50)),
            dict(data="long", x="a", y="y", errorbar=None),
            dict(data="long", x="a", y="y", capsize=.3, err_kws=dict(c="k")),
            dict(data="long", x="a", y="y", color="blue", edgecolor="green", alpha=.5),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """barplot and catplot(kind="bar") produce equal plots for the same input."""
        # Copy so the parametrized dict is not mutated; fix the bootstrap settings
        kwargs = dict(kwargs, seed=0, n_boot=10)

        # Map the string placeholders in the parameter list to the actual fixtures
        fixtures = {
            "long": long_df,
            "wide": wide_df,
            "flat": flat_series,
            "null": null_df,
        }
        source = kwargs["data"]
        if source is None:
            # Without a data object, pass the variables as vectors from long_df
            for var in ("x", "y", "hue"):
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]
        elif source in fixtures:
            kwargs["data"] = fixtures[source]

        ax = barplot(**kwargs)
        g = catplot(**kwargs, kind="bar")

        assert_plots_equal(ax, g.ax)

    def test_errwidth_deprecation(self):
        """errwidth warns as deprecated but still sets the error-bar line width."""
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        requested = 5

        with pytest.warns(FutureWarning, match="\n\nThe `errwidth` parameter"):
            ax = barplot(x=levels, y=heights, errwidth=requested)

        for err_line in ax.lines:
            assert err_line.get_linewidth() == requested

    def test_errcolor_deprecation(self):
        """errcolor warns as deprecated but still sets the error-bar color."""
        levels, heights = ["a", "b", "c"], [1, 2, 3]
        requested = (1, .7, .4, .8)

        with pytest.warns(FutureWarning, match="\n\nThe `errcolor` parameter"):
            ax = barplot(x=levels, y=heights, errcolor=requested)

        for err_line in ax.lines:
            assert err_line.get_color() == requested

    def test_capsize_as_none_deprecation(self):
        """capsize=None warns as deprecated and draws two-point error lines."""
        levels, heights = ["a", "b", "c"], [1, 2, 3]

        with pytest.warns(FutureWarning, match="\n\nPassing `capsize=None`"):
            ax = barplot(x=levels, y=heights, capsize=None)

        for err_line in ax.lines:
            # Only the two endpoints of the bar: no cap vertices
            assert len(err_line.get_xdata()) == 2

    def test_hue_implied_by_palette_deprecation(self):
        """palette without hue warns as deprecated but still colors the bars."""
        levels = ["a", "b", "c"]
        heights = [1, 2, 3]
        palette = "Set1"
        expected_colors = color_palette(palette, len(levels))

        msg = "Passing `palette` without assigning `hue` is deprecated."
        with pytest.warns(FutureWarning, match=msg):
            ax = barplot(x=levels, y=heights, saturation=1, palette=palette)

        for idx, patch in enumerate(ax.patches):
            assert same_color(patch.get_facecolor(), expected_colors[idx])


class TestPointPlot(SharedAggTests):
    """Tests for the axes-level ``pointplot`` function."""

    func = staticmethod(pointplot)

    def get_last_color(self, ax):
        """Return the RGBA color of the most recently added line on ``ax``."""
        color = ax.lines[-1].get_color()
        return to_rgba(color)

    @pytest.mark.parametrize("orient", ["x", "y"])
    def test_single_var(self, orient):
        """A single vector is reduced to its mean along the given axis."""
        vals = pd.Series([1, 3, 10])
        ax = pointplot(**{orient: vals})
        line = ax.lines[0]
        assert getattr(line, f"get_{orient}data")() == approx(vals.mean())

    @pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
    def test_wide_df(self, wide_df, orient):
        """Wide-form data plots one point per column at its mean."""
        ax = pointplot(wide_df, orient=orient)
        # Translate the legacy "h"/"v" aliases to axis names for the checks below
        orient = {"h": "y", "v": "x"}.get(orient, orient)
        depend = {"x": "y", "y": "x"}[orient]
        line = ax.lines[0]
        assert_array_equal(
            getattr(line, f"get_{orient}data")(),
            np.arange(len(wide_df.columns)),
        )
        assert_array_almost_equal(
            getattr(line, f"get_{depend}data")(),
            wide_df.mean(axis=0),
        )

    @pytest.mark.parametrize("orient", ["x", "y", "h", "v"])
    def test_vector_orient(self, orient):
        """A flat mapping of scalars plots its values against its key positions."""
        keys, vals = ["a", "b", "c"], [1, 2, 3]
        data = dict(zip(keys, vals))
        # NOTE(review): orient is normalized *before* the call, so the "h"/"v"
        # aliases are never passed to pointplot here -- confirm this is intended.
        orient = {"h": "y", "v": "x"}.get(orient, orient)
        depend = {"x": "y", "y": "x"}[orient]
        ax = pointplot(data, orient=orient)
        line = ax.lines[0]
        assert_array_equal(
            getattr(line, f"get_{orient}data")(),
            np.arange(len(keys)),
        )
        assert_array_equal(getattr(line, f"get_{depend}data")(), vals)

    def test_xy_vertical(self):
        """Categorical x with numeric y places points at (index, value)."""
        x, y = ["a", "b", "c"], [1, 3, 2.5]
        ax = pointplot(x=x, y=y)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (i, y[i])

    def test_xy_horizontal(self):
        """Numeric x with categorical y places points at (value, index)."""
        x, y = [1, 3, 2.5], ["a", "b", "c"]
        ax = pointplot(x=x, y=y)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (x[i], i)

    def test_xy_with_na_grouper(self):
        """A missing category value is dropped from the ticks and the line."""
        x, y = ["a", None, "b"], [1, 2, 3]
        ax = pointplot(x=x, y=y)
        _draw_figure(ax.figure)  # For matplotlib<3.5
        assert ax.get_xticks() == [0, 1]
        assert [t.get_text() for t in ax.get_xticklabels()] == ["a", "b"]
        assert_array_equal(ax.lines[0].get_xdata(), [0, 1])
        assert_array_equal(ax.lines[0].get_ydata(), [1, 3])

    def test_xy_with_na_value(self):
        """A missing y value keeps its category and appears as NaN in the line."""
        x, y = ["a", "b", "c"], [1, np.nan, 3]
        ax = pointplot(x=x, y=y)
        _draw_figure(ax.figure)  # For matplotlib<3.5
        assert ax.get_xticks() == [0, 1, 2]
        assert [t.get_text() for t in ax.get_xticklabels()] == x
        assert_array_equal(ax.lines[0].get_xdata(), [0, 1, 2])
        assert_array_equal(ax.lines[0].get_ydata(), y)

    def test_hue(self):
        """Each hue level gets its own line and its own color-cycle entry."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        ax = pointplot(x=x, y=y, hue=hue, errorbar=None)
        for i, line in enumerate(ax.lines[:2]):
            # Hue levels alternate in the input, so level i owns every 2nd value
            assert_array_equal(line.get_ydata(), y[i::2])
            assert same_color(line.get_color(), f"C{i}")

    def test_wide_data_is_joined(self, wide_df):
        """Wide-form data is drawn as a single line."""
        ax = pointplot(wide_df, errorbar=None)
        assert len(ax.lines) == 1

    def test_xy_native_scale(self):
        """native_scale=True keeps the original numeric x positions."""
        x, y = [2, 4, 8], [1, 2, 3]

        ax = pointplot(x=x, y=y, native_scale=True)
        line = ax.lines[0]
        assert_array_equal(line.get_xdata(), x)
        assert_array_equal(line.get_ydata(), y)

    # Use lambda around np.mean to avoid uninformative pandas deprecation warning
    @pytest.mark.parametrize("estimator", ["mean", lambda x: np.mean(x)])
    def test_estimate(self, long_df, estimator):
        """String and callable estimators both set the plotted aggregate."""
        agg_var, val_var = "a", "y"
        agg_df = long_df.groupby(agg_var)[val_var].agg(estimator)

        # Forward the parametrized estimator so the callable case is exercised too
        ax = pointplot(
            long_df, x=agg_var, y=val_var, errorbar=None, estimator=estimator,
        )
        order = categorical_order(long_df[agg_var])
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == approx((i, agg_df[order[i]]))

    def test_weighted_estimate(self, long_df):
        """The weights variable produces a weighted average estimate."""
        ax = pointplot(long_df, y="y", weights="x")
        val = ax.lines[0].get_ydata().item()
        expected = np.average(long_df["y"], weights=long_df["x"])
        assert val == expected

    def test_estimate_log_transform(self, long_df):
        """On a log-scaled axis, the estimate is the mean of the log10 values."""
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        pointplot(x=long_df["z"], ax=ax)
        val, = ax.lines[0].get_xdata()
        assert val == 10 ** np.log10(long_df["z"]).mean()

    def test_errorbars(self, long_df):
        """errorbar="sd" draws lines spanning one standard deviation around the mean."""
        agg_var, val_var = "a", "y"
        agg_df = long_df.groupby(agg_var)[val_var].agg(["mean", "std"])

        ax = pointplot(long_df, x=agg_var, y=val_var, errorbar="sd")
        order = categorical_order(long_df[agg_var])
        # The first line is the point estimates; the rest are the error bars
        for i, line in enumerate(ax.lines[1:]):
            row = agg_df.loc[order[i]]
            lo, hi = line.get_ydata()
            assert lo == approx(row["mean"] - row["std"])
            assert hi == approx(row["mean"] + row["std"])

    def test_marker_linestyle(self):
        """The singular marker/linestyle keywords style the main line."""
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = pointplot(x=x, y=y, marker="s", linestyle="--")
        line = ax.lines[0]
        assert line.get_marker() == "s"
        assert line.get_linestyle() == "--"

    def test_markers_linestyles_single(self):
        """The plural markers/linestyles keywords accept single values."""
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = pointplot(x=x, y=y, markers="s", linestyles="--")
        line = ax.lines[0]
        assert line.get_marker() == "s"
        assert line.get_linestyle() == "--"

    def test_markers_linestyles_mapped(self):
        """Lists of markers/linestyles are assigned to hue levels in order."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        markers = ["d", "s"]
        linestyles = ["--", ":"]
        ax = pointplot(
            x=x, y=y, hue=hue,
            markers=markers, linestyles=linestyles,
            errorbar=None,
        )
        for i, line in enumerate(ax.lines[:2]):
            assert line.get_marker() == markers[i]
            assert line.get_linestyle() == linestyles[i]

    def test_dodge_boolean(self):
        """dodge=True offsets the two hue levels by -/+ .025 in this layout."""
        x, y = ["a", "b", "a", "b"], [1, 2, 3, 4]
        hue = ["x", "x", "y", "y"]
        ax = pointplot(x=x, y=y, hue=hue, dodge=True, errorbar=None)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (i - .025, y[i])
        for i, xy in enumerate(ax.lines[1].get_xydata()):
            assert tuple(xy) == (i + .025, y[2 + i])

    def test_dodge_float(self):
        """A numeric dodge is split evenly across the two hue levels."""
        x, y = ["a", "b", "a", "b"], [1, 2, 3, 4]
        hue = ["x", "x", "y", "y"]
        ax = pointplot(x=x, y=y, hue=hue, dodge=.2, errorbar=None)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == (i - .1, y[i])
        for i, xy in enumerate(ax.lines[1].get_xydata()):
            assert tuple(xy) == (i + .1, y[2 + i])

    def test_dodge_log_scale(self):
        """On a log-scaled native axis, the dodge offset is applied in log10 units."""
        x, y = [10, 1000, 10, 1000], [1, 2, 3, 4]
        hue = ["x", "x", "y", "y"]
        ax = mpl.figure.Figure().subplots()
        ax.set_xscale("log")
        pointplot(x=x, y=y, hue=hue, dodge=.2, native_scale=True, errorbar=None, ax=ax)
        for i, xy in enumerate(ax.lines[0].get_xydata()):
            assert tuple(xy) == approx((10 ** (np.log10(x[i]) - .2), y[i]))
        for i, xy in enumerate(ax.lines[1].get_xydata()):
            assert tuple(xy) == approx((10 ** (np.log10(x[2 + i]) + .2), y[2 + i]))

    def test_err_kws(self):
        """err_kws properties are applied to the error-bar lines."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        err_kws = dict(color=(.2, .5, .3), linewidth=10)
        ax = pointplot(x=x, y=y, errorbar=("pi", 100), err_kws=err_kws)
        for line in ax.lines[1:]:
            assert same_color(line.get_color(), err_kws["color"])
            assert line.get_linewidth() == err_kws["linewidth"]

    def test_err_kws_inherited(self):
        """Error bars inherit line properties passed as plain keyword arguments."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        kws = dict(color=(.2, .5, .3), linewidth=10)
        ax = pointplot(x=x, y=y, errorbar=("pi", 100), **kws)
        for line in ax.lines[1:]:
            assert same_color(line.get_color(), kws["color"])
            assert line.get_linewidth() == kws["linewidth"]

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Legend handle missing marker property"
    )
    def test_legend_contents(self):
        """The legend lists the hue levels with default marker, line, and colors."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        ax = pointplot(x=x, y=y, hue=hue)
        _draw_figure(ax.figure)
        legend = ax.get_legend()
        assert [t.get_text() for t in legend.texts] == ["x", "y"]
        for i, handle in enumerate(get_legend_handles(legend)):
            assert handle.get_marker() == "o"
            assert handle.get_linestyle() == "-"
            assert same_color(handle.get_color(), f"C{i}")

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Legend handle missing marker property"
    )
    def test_legend_set_props(self):
        """Line properties passed as keywords also show on the legend handles."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        kws = dict(marker="s", linewidth=1)
        ax = pointplot(x=x, y=y, hue=hue, **kws)
        legend = ax.get_legend()
        for handle in get_legend_handles(legend):
            assert handle.get_marker() == kws["marker"]
            assert handle.get_linewidth() == kws["linewidth"]

    @pytest.mark.skipif(
        _version_predates(mpl, "3.6"),
        reason="Legend handle missing marker property"
    )
    def test_legend_synced_props(self):
        """Per-level markers and linestyles are mirrored on the legend handles."""
        x, y = ["a", "a", "b", "b"], [1, 2, 3, 4]
        hue = ["x", "y", "x", "y"]
        kws = dict(markers=["s", "d"], linestyles=["--", ":"])
        ax = pointplot(x=x, y=y, hue=hue, **kws)
        legend = ax.get_legend()
        for i, handle in enumerate(get_legend_handles(legend)):
            assert handle.get_marker() == kws["markers"][i]
            assert handle.get_linestyle() == kws["linestyles"][i]

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a", y="y"),
            dict(data=None, x="a", y="y"),
            dict(data="long", x="a", y="y", hue="a"),
            dict(data=None, x="a", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="b"),
            dict(data=None, x="s", y="y", hue="a"),
            dict(data="long", x="a", y="y", hue="s"),
            dict(data="long", x="a", y="y", units="c"),
            dict(data="null", x="a", y="y", hue="a"),
            dict(data="long", x="s", y="y", hue="a", native_scale=True),
            dict(data="long", x="d", y="y", hue="a", native_scale=True),
            dict(data="long", x="a", y="y", errorbar=("pi", 50)),
            dict(data="long", x="a", y="y", errorbar=None),
            dict(data="null", x="a", y="y", hue="a", dodge=True),
            dict(data="null", x="a", y="y", hue="a", dodge=.2),
            dict(data="long", x="a", y="y", capsize=.3, err_kws=dict(c="k")),
            dict(data="long", x="a", y="y", color="blue", marker="s"),
            dict(data="long", x="a", y="y", hue="a", markers=["s", "d", "p"]),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """pointplot and catplot(kind="point") produce equal plots for the same input."""
        # Copy so the parametrized dict is not mutated; fix the bootstrap settings
        kwargs = kwargs.copy()
        kwargs["seed"] = 0
        kwargs["n_boot"] = 10

        # Swap the string placeholders in the parameter list for the fixtures
        if kwargs["data"] == "long":
            kwargs["data"] = long_df
        elif kwargs["data"] == "wide":
            kwargs["data"] = wide_df
        elif kwargs["data"] == "flat":
            kwargs["data"] = flat_series
        elif kwargs["data"] == "null":
            kwargs["data"] = null_df
        elif kwargs["data"] is None:
            # Without a data object, pass the variables as vectors from long_df
            for var in ["x", "y", "hue"]:
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]

        ax = pointplot(**kwargs)
        g = catplot(**kwargs, kind="point")

        assert_plots_equal(ax, g.ax)

    def test_legend_disabled(self, long_df):
        """legend=False leaves the axes without a legend."""
        ax = pointplot(long_df, x="x", y="y", hue="b", legend=False)
        assert ax.get_legend() is None

    def test_join_deprecation(self):
        """join=False warns as deprecated and removes the connecting line."""
        with pytest.warns(UserWarning, match="The `join` parameter"):
            ax = pointplot(x=["a", "b", "c"], y=[1, 2, 3], join=False)
        assert ax.lines[0].get_linestyle().lower() == "none"

    def test_scale_deprecation(self):
        """scale warns as deprecated but still scales line width and marker size."""
        x, y = ["a", "b", "c"], [1, 2, 3]
        ax = pointplot(x=x, y=y, errorbar=None)
        with pytest.warns(UserWarning, match="The `scale` parameter"):
            pointplot(x=x, y=y, errorbar=None, scale=2)
        # Unpacking ax.lines relies on both calls having drawn on the same axes
        l1, l2 = ax.lines
        assert l2.get_linewidth() == 2 * l1.get_linewidth()
        assert l2.get_markersize() > l1.get_markersize()

    def test_layered_plot_clipping(self):
        """Layering a second plot keeps the first plot's point inside the y limits."""
        x, y = ['a'], [4]
        pointplot(x=x, y=y)
        x, y = ['b'], [5]
        ax = pointplot(x=x, y=y)
        y_range = ax.viewLim.intervaly
        assert y_range[0] < 4 and y_range[1] > 5


class TestCountPlot:
    """Tests for the axes-level ``countplot`` function."""

    def test_empty(self):
        """No bars are drawn when there is no data at all or an empty vector."""
        no_args_ax = countplot()
        assert not no_args_ax.patches

        empty_vector_ax = countplot(x=[])
        assert not empty_vector_ax.patches

    def test_labels_long(self, long_df):
        """Axis labels come from the variable name and the chosen statistic."""
        fig = mpl.figure.Figure()
        top, bottom = fig.subplots(2)
        countplot(long_df, x="a", ax=top)
        countplot(long_df, x="b", stat="percent", ax=bottom)

        # To populate texts; only needed on older matplotlibs
        _draw_figure(fig)

        assert top.get_xlabel() == "a"
        assert bottom.get_xlabel() == "b"
        assert top.get_ylabel() == "count"
        assert bottom.get_ylabel() == "percent"

    def test_wide_data(self, wide_df):
        """Wide-form data gives one bar per column, as tall as the row count."""
        ax = countplot(wide_df)
        n_rows = len(wide_df)

        assert len(ax.patches) == len(wide_df.columns)
        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(idx)
            assert patch.get_y() == 0
            assert patch.get_height() == n_rows
            assert patch.get_width() == approx(0.8)

    def test_flat_series(self):
        """A positional Series is counted along y, giving horizontal bars."""
        levels = ["a", "b", "c"]
        counts = [2, 1, 4]
        # Repeat each level as many times as it should be counted
        expanded = []
        for level, n in zip(levels, counts):
            expanded.extend([level] * n)
        observed = pd.Series(expanded)

        ax = countplot(observed)
        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_y() + patch.get_height() / 2
            assert patch.get_x() == 0
            assert midpoint == approx(idx)
            assert patch.get_height() == approx(0.8)
            assert patch.get_width() == counts[idx]

    def test_x_series(self):
        """A Series passed as x is counted into vertical bars."""
        levels = ["a", "b", "c"]
        counts = [2, 1, 4]
        # Repeat each level as many times as it should be counted
        expanded = []
        for level, n in zip(levels, counts):
            expanded.extend([level] * n)
        observed = pd.Series(expanded)

        ax = countplot(x=observed)
        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(idx)
            assert patch.get_y() == 0
            assert patch.get_height() == counts[idx]
            assert patch.get_width() == approx(0.8)

    def test_y_series(self):
        """A Series passed as y is counted into horizontal bars."""
        levels = ["a", "b", "c"]
        counts = [2, 1, 4]
        # Repeat each level as many times as it should be counted
        expanded = []
        for level, n in zip(levels, counts):
            expanded.extend([level] * n)
        observed = pd.Series(expanded)

        ax = countplot(y=observed)
        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_y() + patch.get_height() / 2
            assert patch.get_x() == 0
            assert midpoint == approx(idx)
            assert patch.get_height() == approx(0.8)
            assert patch.get_width() == counts[idx]

    def test_hue_redundant(self):
        """Hue identical to x colors each bar and keeps it full-width and centered."""
        levels = ["a", "b", "c"]
        counts = [2, 1, 4]
        # Repeat each level as many times as it should be counted
        expanded = []
        for level, n in zip(levels, counts):
            expanded.extend([level] * n)
        observed = pd.Series(expanded)

        ax = countplot(x=observed, hue=observed, saturation=1)
        for idx, patch in enumerate(ax.patches):
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(idx)
            assert patch.get_y() == 0
            assert patch.get_height() == counts[idx]
            assert patch.get_width() == approx(0.8)
            assert same_color(patch.get_facecolor(), f"C{idx}")

    def test_hue_dodged(self):
        """Crossed x and hue levels give dodged bars, including an empty cell."""
        cats = ["a", "a", "a", "b", "b", "b"]
        groups = ["x", "y", "y", "x", "x", "x"]
        # Expected heights, hue level by hue level: (x, a), (x, b), (y, a), (y, b)
        counts = [1, 3, 2, 0]

        ax = countplot(x=cats, hue=groups, saturation=1, legend=False)
        for idx, patch in enumerate(ax.patches):
            hue_idx, cat_idx = divmod(idx, 2)
            direction = 1 if hue_idx else -1
            midpoint = patch.get_x() + patch.get_width() / 2
            assert midpoint == approx(cat_idx + direction * 0.8 / 4)
            assert patch.get_y() == 0
            assert patch.get_height() == counts[idx]
            assert patch.get_width() == approx(0.8 / 2)
            assert same_color(patch.get_facecolor(), f"C{hue_idx}")

    @pytest.mark.parametrize("stat", ["percent", "probability", "proportion"])
    def test_stat(self, long_df, stat):
        """Normalized statistics match pandas' normalized value counts."""
        col = "a"
        levels = categorical_order(long_df[col])
        expected = long_df[col].value_counts(normalize=True)
        if stat == "percent":
            expected = expected * 100

        ax = countplot(long_df, x=col, stat=stat)
        for idx, patch in enumerate(ax.patches):
            level = levels[idx]
            assert patch.get_height() == approx(expected[level])

    def test_xy_error(self, long_df):
        """Assigning both x and y raises a TypeError."""
        with pytest.raises(TypeError, match="Cannot pass values for both"):
            countplot(long_df, x="a", y="b")

    def test_legend_numeric_auto(self, long_df):
        """The default legend for a numeric hue shows at most six entries."""
        ax = countplot(long_df, x="x", hue="x")
        entries = ax.get_legend().texts
        assert len(entries) <= 6

    def test_legend_disabled(self, long_df):
        """legend=False leaves the axes without a legend."""
        ax = countplot(long_df, x="x", hue="b", legend=False)
        legend = ax.get_legend()
        assert legend is None

    @pytest.mark.parametrize(
        "kwargs",
        [
            dict(data="wide"),
            dict(data="wide", orient="h"),
            dict(data="flat"),
            dict(data="long", x="a"),
            dict(data=None, x="a"),
            dict(data="long", y="b"),
            dict(data="long", x="a", hue="a"),
            dict(data=None, x="a", hue="a"),
            dict(data="long", x="a", hue="b"),
            dict(data=None, x="s", hue="a"),
            dict(data="long", x="a", hue="s"),
            dict(data="null", x="a", hue="a"),
            dict(data="long", x="s", hue="a", native_scale=True),
            dict(data="long", x="d", hue="a", native_scale=True),
            dict(data="long", x="a", stat="percent"),
            dict(data="long", x="a", hue="b", stat="proportion"),
            dict(data="long", x="a", color="blue", ec="green", alpha=.5),
        ]
    )
    def test_vs_catplot(self, long_df, wide_df, null_df, flat_series, kwargs):
        """countplot and catplot(kind="count") produce equal plots for the same input."""
        # Copy so the parametrized dict is not mutated
        kwargs = dict(kwargs)

        # Map the string placeholders in the parameter list to the actual fixtures
        fixtures = {
            "long": long_df,
            "wide": wide_df,
            "flat": flat_series,
            "null": null_df,
        }
        source = kwargs["data"]
        if source is None:
            # Without a data object, pass the variables as vectors from long_df
            for var in ("x", "y", "hue"):
                if var in kwargs:
                    kwargs[var] = long_df[kwargs[var]]
        elif source in fixtures:
            kwargs["data"] = fixtures[source]

        ax = countplot(**kwargs)
        g = catplot(**kwargs, kind="count")

        assert_plots_equal(ax, g.ax)


class CategoricalFixture:
    """Random data and helpers shared by test classes that subclass this."""
    # NOTE: the attributes below draw from ``rs`` in sequence, so their order
    # determines the generated values.
    rs = np.random.RandomState(30)
    n_total = 60
    x = rs.randn(n_total // 3, 3)
    x_df = pd.DataFrame(x, columns=pd.Series(list("XYZ"), name="big"))
    y = pd.Series(rs.randn(n_total), name="y_data")
    y_perm = y.reindex(rs.choice(y.index, y.size, replace=False))
    g = pd.Series(np.repeat(list("abc"), n_total // 3), name="small")
    h = pd.Series(np.tile(list("mn"), n_total // 2), name="medium")
    u = pd.Series(np.tile(list("jkh"), n_total // 3))
    df = pd.DataFrame({"y": y, "g": g, "h": h, "u": u})
    x_df["W"] = g

    def get_box_artists(self, ax):
        """Return the artists on ``ax`` that are counted as boxes."""
        if not _version_predates(mpl, "3.5.0b0"):
            # Exclude labeled patches, which are for the legend
            return [patch for patch in ax.patches if not patch.get_label()]
        return ax.artists


class TestCatPlot(CategoricalFixture):
    """Tests for the figure-level ``catplot`` function."""

    def test_facet_organization(self):
        """The ``col`` and ``row`` variables set the shape of the axes grid."""
        expected_shapes = [
            ({}, (1, 1)),
            ({"col": "h"}, (1, 2)),
            ({"row": "h"}, (2, 1)),
            ({"col": "u", "row": "h"}, (2, 3)),
        ]
        for facet_kws, shape in expected_shapes:
            grid = cat.catplot(x="g", y="y", data=self.df, **facet_kws)
            assert grid.axes.shape == shape

    def test_plot_elements(self):
        """Each plot kind adds the expected number of artists to the axes."""
        n_g = len(self.g.unique())
        n_h = len(self.h.unique())
        xy = dict(x="g", y="y", data=self.df)

        # Point plots
        grid = cat.catplot(**xy, kind="point")
        assert len(grid.ax.lines) == 1 + n_g

        grid = cat.catplot(**xy, hue="h", kind="point")
        assert len(grid.ax.lines) == n_g * n_h + 2 * n_h

        # Bar plots
        grid = cat.catplot(**xy, kind="bar")
        assert len(grid.ax.patches) == n_g
        assert len(grid.ax.lines) == n_g

        grid = cat.catplot(**xy, hue="h", kind="bar")
        assert len(grid.ax.patches) == n_g * n_h + n_h
        assert len(grid.ax.lines) == n_g * n_h

        # Count plots
        grid = cat.catplot(x="g", data=self.df, kind="count")
        assert len(grid.ax.patches) == n_g
        assert len(grid.ax.lines) == 0

        grid = cat.catplot(x="g", hue="h", data=self.df, kind="count")
        assert len(grid.ax.patches) == n_g * n_h + n_h
        assert len(grid.ax.lines) == 0

        # Box plots
        grid = cat.catplot(y="y", data=self.df, kind="box")
        assert len(self.get_box_artists(grid.ax)) == 1

        grid = cat.catplot(**xy, kind="box")
        assert len(self.get_box_artists(grid.ax)) == n_g

        grid = cat.catplot(**xy, hue="h", kind="box")
        assert len(self.get_box_artists(grid.ax)) == n_g * n_h

        # Violin plots
        grid = cat.catplot(**xy, kind="violin", inner=None)
        assert len(grid.ax.collections) == n_g

        grid = cat.catplot(**xy, hue="h", kind="violin", inner=None)
        assert len(grid.ax.collections) == n_g * n_h

        # Strip plots
        grid = cat.catplot(**xy, kind="strip")
        assert len(grid.ax.collections) == n_g
        for points in grid.ax.collections:
            assert same_color(points.get_facecolors(), "C0")

        grid = cat.catplot(**xy, hue="h", kind="strip")
        assert len(grid.ax.collections) == n_g

    def test_bad_plot_kind_error(self):
        """An unrecognized ``kind`` raises a ValueError."""
        bad_kws = dict(x="g", y="y", data=self.df, kind="not_a_kind")
        with pytest.raises(ValueError):
            cat.catplot(**bad_kws)

    def test_count_x_and_y(self):
        """A count plot given both ``x`` and ``y`` raises a ValueError."""
        both_kws = dict(x="g", y="y", data=self.df, kind="count")
        with pytest.raises(ValueError):
            cat.catplot(**both_kws)

    def test_plot_colors(self):
        """Compare artist colors between catplot and axes-level functions."""
        xy = dict(x="g", y="y", data=self.df)

        # Bars: compare the face colors of the patches
        bar_cases = [
            {},
            {"color": "purple"},
            {"palette": "Set2", "hue": "h"},
        ]
        for color_kws in bar_cases:
            ax = cat.barplot(**xy, **color_kws)
            grid = cat.catplot(**xy, kind="bar", **color_kws)
            for ref, new in zip(ax.patches, grid.ax.patches):
                assert ref.get_facecolor() == new.get_facecolor()
            plt.close("all")

        # Points: compare the colors of the lines
        # NOTE(review): the first case passes no `kind`, so catplot uses its
        # default kind there -- confirm that comparison is intended.
        point_cases = [
            ({}, {}),
            ({"color": "purple"}, {"kind": "point"}),
            ({"palette": "Set2", "hue": "h"}, {"kind": "point"}),
        ]
        for color_kws, grid_kws in point_cases:
            ax = cat.pointplot(**xy, **color_kws)
            grid = cat.catplot(**xy, **color_kws, **grid_kws)
            for ref, new in zip(ax.lines, grid.ax.lines):
                assert ref.get_color() == new.get_color()
            plt.close("all")

    def test_ax_kwarg_removal(self):
        """Passing ``ax`` warns and nothing is drawn on the given axes."""
        f, given_ax = plt.subplots()
        with pytest.warns(UserWarning, match="catplot is a figure-level"):
            grid = cat.catplot(x="g", y="y", data=self.df, ax=given_ax)
        assert len(given_ax.collections) == 0
        assert len(grid.ax.collections) > 0

    def test_share_xy(self):
        """Check artist counts and limits with shared and unshared axes."""
        n_levels = len(self.df.g.unique())
        vertical = dict(x="g", y="y", col="g", data=self.df)
        horizontal = dict(x="y", y="g", col="g", data=self.df)

        # Test default behavior works
        grid = cat.catplot(**vertical, sharex=True)
        for ax in grid.axes.flat:
            assert len(ax.collections) == n_levels

        grid = cat.catplot(**horizontal, sharey=True)
        for ax in grid.axes.flat:
            assert len(ax.collections) == n_levels

        # Test unsharing works
        grid = cat.catplot(**vertical, sharex=False, kind="bar")
        for ax in grid.axes.flat:
            assert len(ax.patches) == 1

        grid = cat.catplot(**horizontal, sharey=False, kind="bar")
        for ax in grid.axes.flat:
            assert len(ax.patches) == 1

        grid = cat.catplot(**vertical, sharex=False, color="b")
        for ax in grid.axes.flat:
            assert ax.get_xlim() == (-.5, .5)

        grid = cat.catplot(**horizontal, sharey=False, color="r")
        for ax in grid.axes.flat:
            assert ax.get_ylim() == (.5, -.5)

        # Make sure order is used if given, regardless of sharex value
        order = self.df.g.unique()
        grid = cat.catplot(**vertical, sharex=False, order=order)
        for ax in grid.axes.flat:
            assert len(ax.collections) == n_levels

        grid = cat.catplot(**horizontal, sharey=False, order=order)
        for ax in grid.axes.flat:
            assert len(ax.collections) == n_levels

    def test_facetgrid_data(self, long_df):
        """Check the ``data`` attribute for DataFrame and vector inputs."""
        from_frame = catplot(data=long_df, x="a", y="y", col="c")
        assert from_frame.data is long_df

        from_vectors = catplot(x=long_df["a"], y=long_df["y"], col=long_df["c"])
        expected = long_df[["a", "y", "c"]]
        assert from_vectors.data.equals(expected)

    @pytest.mark.parametrize("var", ["col", "row"])
    def test_array_faceter(self, long_df, var):
        """Faceting on an array matches faceting on the named column."""
        by_name = catplot(data=long_df, x="y", **{var: "a"})
        by_array = catplot(data=long_df, x="y", **{var: long_df["a"].to_numpy()})

        for ax_name, ax_array in zip(by_name.axes.flat, by_array.axes.flat):
            assert_plots_equal(ax_name, ax_array)

    def test_invalid_kind(self, long_df):
        """The error for an invalid ``kind`` names the offending value."""
        expected_msg = "Invalid `kind`: 'wrong'"
        with pytest.raises(ValueError, match=expected_msg):
            catplot(long_df, kind="wrong")

    def test_legend_with_auto(self):
        """Compare ``legend='auto'`` and ``legend=True`` when hue equals x."""
        auto_grid = catplot(self.df, x="g", y="y", hue="g", legend='auto')
        assert auto_grid._legend is None

        forced_grid = catplot(self.df, x="g", y="y", hue="g", legend=True)
        assert forced_grid._legend is not None

    def test_weights_warning(self, long_df):
        """Passing ``weights`` here warns but still produces a plot."""
        with pytest.warns(UserWarning, match="The `weights` parameter"):
            grid = catplot(long_df, x="a", y="y", weights="z")
        assert grid.ax is not None


class TestBeeswarm:
    """Unit tests for the methods of the ``Beeswarm`` helper."""

    def test_could_overlap(self):
        """``could_overlap`` returns the expected subset of the swarm."""
        point = (1, 1, .5)
        placed = [
            (0, 0, .5),
            (1, .1, .2),
            (.5, .5, .5),
        ]
        neighbors = Beeswarm().could_overlap(point, placed)
        assert_array_equal(neighbors, [(.5, .5, .5)])

    def test_position_candidates(self):
        """``position_candidates`` returns the expected candidate list."""
        point = (0, 1, .5)
        nearby = [(0, 1, .5), (0, 1.5, .5)]
        candidates = Beeswarm().position_candidates(point, nearby)

        dx1 = 1.05
        dx2 = np.sqrt(1 - .5 ** 2) * 1.05
        expected = [
            (0, 1, .5),
            (-dx1, 1, .5),
            (dx1, 1, .5),
            (dx2, 1, .5),
            (-dx2, 1, .5),
        ]
        assert_array_equal(candidates, expected)

    def test_find_first_non_overlapping_candidate(self):
        """The chosen candidate is the expected one from the list."""
        options = [(.5, 1, .5), (1, 1, .5), (1.5, 1, .5)]
        placed = np.array([(0, 1, .5)])

        chosen = Beeswarm().first_non_overlapping_candidate(options, placed)
        assert_array_equal(chosen, (1, 1, .5))

    def test_beeswarm(self, long_df):
        """Swarmed points keep their y values and stay more than ``d`` apart."""
        values = long_df["y"]
        d = values.diff().mean() * 1.5

        y = np.sort(values)
        x = np.zeros(values.size)
        r = np.full_like(y, d)
        orig_xyr = np.c_[x, y, r]

        swarm = Beeswarm().beeswarm(orig_xyr)[:, :2]

        # Pairwise Euclidean distances between all swarmed points
        offsets = swarm[:, np.newaxis] - swarm
        dmat = np.sqrt(np.square(offsets).sum(axis=-1))
        pair_dists = dmat[np.triu_indices_from(dmat, 1)]

        assert_array_less(d, pair_dists)
        assert_array_equal(y, swarm[:, 1])

    def test_add_gutters(self):
        """Check ``add_gutters`` output, including the warning it can emit."""
        swarm_obj = Beeswarm(width=1)

        def identity(val):
            return val

        centered = np.zeros(10)
        result = swarm_obj.add_gutters(centered, 0, identity, identity)
        assert_array_equal(centered, result)

        spread = np.array([0, -1, .4, .8])
        msg = r"50.0% of the points cannot be placed.+$"
        with pytest.warns(UserWarning, match=msg):
            result = swarm_obj.add_gutters(spread, 0, identity, identity)
        assert_array_equal(result, np.array([0, -.5, .4, .5]))


class TestBoxPlotContainer:
    """Tests for the ``BoxPlotContainer`` wrapper around boxplot artists."""

    @pytest.fixture
    def container(self, wide_array):
        """Build a container from a matplotlib boxplot of ``wide_array``."""
        ax = mpl.figure.Figure().subplots()
        artist_dict = ax.boxplot(wide_array)
        return BoxPlotContainer(artist_dict)

    def test_repr(self, container, wide_array):
        """The string form of the container reports its number of boxes."""
        # The boxplot is drawn from a 2D array, giving one box per column
        n = wide_array.shape[1]
        # The expected text had been reduced to an empty f-string, leaving
        # ``n`` unused; compare against the full repr instead.
        assert str(container) == f"<BoxPlotContainer object with {n} boxes>"

    def test_iteration(self, container):
        """Iterating yields objects exposing each boxplot artist type."""
        expected_attrs = ["box", "median", "whiskers", "caps", "fliers", "mean"]
        for artist_tuple in container:
            for attr in expected_attrs:
                assert hasattr(artist_tuple, attr)

    def test_label(self, container):
        """A label set on the container can be read back."""
        label = "a box plot"
        container.set_label(label)
        assert container.get_label() == label

    def test_children(self, container):
        """Every child of the container is a matplotlib artist."""
        for child in container.get_children():
            assert isinstance(child, mpl.artist.Artist)


================================================
FILE: tests/test_distributions.py
================================================
import itertools
import warnings

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import to_rgb, to_rgba

import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn import distributions as dist
from seaborn.palettes import (
    color_palette,
    light_palette,
)
from seaborn._base import (
    categorical_order,
)
from seaborn._statistics import (
    KDE,
    Histogram,
    _no_scipy,
)
from seaborn.distributions import (
    _DistributionPlotter,
    displot,
    distplot,
    histplot,
    ecdfplot,
    kdeplot,
    rugplot,
)
from seaborn.utils import _version_predates
from seaborn.axisgrid import FacetGrid
from seaborn._testing import (
    assert_plots_equal,
    assert_legends_equal,
    assert_colors_equal,
)


def get_contour_coords(c, filter_empty=False):
    """Provide compatibility for change in contour artist types.

    Returns the segments of a LineCollection, or a list of vertex arrays for
    a PathCollection / QuadContourSet (skipping zero-length paths when
    ``filter_empty`` is true). Other inputs yield ``None``.
    """
    if isinstance(c, mpl.collections.LineCollection):
        # See https://github.com/matplotlib/matplotlib/issues/20906
        return c.get_segments()

    path_types = (mpl.collections.PathCollection, mpl.contour.QuadContourSet)
    if isinstance(c, path_types):
        coords = []
        for path in c.get_paths():
            if filter_empty and not len(path):
                continue
            # Keep vertices up to and including the first maximal path code
            stop = np.argmax(path.codes) + 1
            coords.append(path.vertices[:stop])
        return coords

    return None


def get_contour_color(c):
    """Provide compatibility for change in contour artist types.

    Returns the color of a LineCollection, or for a PathCollection /
    QuadContourSet the face color when one is set and the edge color
    otherwise. Other inputs yield ``None``.
    """
    if isinstance(c, mpl.collections.LineCollection):
        # See https://github.com/matplotlib/matplotlib/issues/20906
        return c.get_color()

    path_types = (mpl.collections.PathCollection, mpl.contour.QuadContourSet)
    if isinstance(c, path_types):
        has_face = c.get_facecolor().size
        return c.get_facecolor() if has_face else c.get_edgecolor()

    return None


class TestDistPlot:
    """Tests for ``distplot``; each call is expected to emit a UserWarning."""

    rs = np.random.RandomState(0)
    x = rs.randn(100)

    def test_hist_bins(self):
        """Bars start at the bin edges numpy computes for the same rule."""

        def assert_bars_at(edges, ax):
            for edge, bar in zip(edges, ax.patches):
                assert pytest.approx(edge) == bar.get_x()

        # Default binning is compared against the "fd" reference rule
        fd_edges = np.histogram_bin_edges(self.x, "fd")
        with pytest.warns(UserWarning):
            ax = distplot(self.x)
        assert_bars_at(fd_edges, ax)

        plt.close(ax.figure)

        # An explicit bin count
        n = 25
        n_edges = np.histogram_bin_edges(self.x, n)
        with pytest.warns(UserWarning):
            ax = distplot(self.x, bins=n)
        assert_bars_at(n_edges, ax)

    def test_elements(self):
        """Each component flag adds only its own kind of artist."""

        class Norm:
            """Dummy object that looks like a scipy RV"""
            def fit(self, x):
                return ()

            def pdf(self, x, *params):
                return np.zeros_like(x)

        def artist_counts(ax):
            # (patches, lines, collections)
            return len(ax.patches), len(ax.lines), len(ax.collections)

        with pytest.warns(UserWarning):

            n = 10
            ax = distplot(
                self.x, bins=n, hist=True, kde=False, rug=False, fit=None
            )
            assert artist_counts(ax) == (10, 0, 0)

            plt.close(ax.figure)
            ax = distplot(self.x, hist=False, kde=True, rug=False, fit=None)
            assert artist_counts(ax) == (0, 1, 0)

            plt.close(ax.figure)
            ax = distplot(self.x, hist=False, kde=False, rug=True, fit=None)
            assert artist_counts(ax) == (0, 0, 1)

            plt.close(ax.figure)
            ax = distplot(
                self.x, hist=False, kde=False, rug=False, fit=Norm()
            )
            assert artist_counts(ax) == (0, 1, 0)

    def test_distplot_with_nans(self):
        """Appending a NaN to the data leaves the line and bars unchanged."""
        f, (ax_clean, ax_null) = plt.subplots(2)
        x_null = np.append(self.x, [np.nan])

        with pytest.warns(UserWarning):
            distplot(self.x, ax=ax_clean)
            distplot(x_null, ax=ax_null)

        clean_xy = ax_clean.lines[0].get_xydata()
        null_xy = ax_null.lines[0].get_xydata()
        assert np.array_equal(clean_xy, null_xy)

        for clean_bar, null_bar in zip(ax_clean.patches, ax_null.patches):
            assert clean_bar.get_xy() == null_bar.get_xy()
            assert clean_bar.get_height() == null_bar.get_height()


class SharedAxesLevelTests:
    """Checks shared by test classes for several axes-level functions.

    Subclasses provide ``func`` and ``get_last_color``, which are used here.
    """

    def test_color(self, long_df, **kwargs):
        """Check color cycling across calls and the ``color`` parameter."""
        # The n-th call on the same axes should use the n-th cycle color
        for n_calls, expected in [(1, "C0"), (2, "C1")]:
            ax = plt.figure().subplots()
            for _ in range(n_calls):
                self.func(data=long_df, x="y", ax=ax, **kwargs)
            last = self.get_last_color(ax, **kwargs)
            assert_colors_equal(last, expected, check_alpha=False)

        # An explicit color is used as given
        ax = plt.figure().subplots()
        self.func(data=long_df, x="y", color="C2", ax=ax, **kwargs)
        last = self.get_last_color(ax, **kwargs)
        assert_colors_equal(last, "C2", check_alpha=False)


class TestRugPlot(SharedAxesLevelTests):
    """Tests for ``rugplot``."""

    func = staticmethod(rugplot)

    def get_last_color(self, ax, **kwargs):
        """Return the color of the most recently added collection."""
        newest = ax.collections[-1]
        return newest.get_color()

    def assert_rug_equal(self, a, b):
        """Assert that two rug collections have identical segments."""
        assert_array_equal(a.get_segments(), b.get_segments())

    @pytest.mark.parametrize("variable", ["x", "y"])
    def test_long_data(self, long_df, variable):
        """Name, Series, array, and list inputs all draw the same rug."""
        series = long_df[variable]
        specs = [variable, series, np.asarray(series), series.to_list()]

        f, ax = plt.subplots()
        for spec in specs:
            rugplot(data=long_df, **{variable: spec})

        for a, b in itertools.product(ax.collections, repeat=2):
            self.assert_rug_equal(a, b)

    def test_bivariate_data(self, long_df):
        """A bivariate rug matches separate x and y rugs."""
        f, (ax_joint, ax_separate) = plt.subplots(ncols=2)

        rugplot(data=long_df, x="x", y="y", ax=ax_joint)
        rugplot(data=long_df, x="x", ax=ax_separate)
        rugplot(data=long_df, y="y", ax=ax_separate)

        for idx in (0, 1):
            self.assert_rug_equal(
                ax_joint.collections[idx], ax_separate.collections[idx]
            )

    def test_wide_vs_long_data(self, wide_df):
        """Wide-form input draws the same segments as per-column calls."""
        f, (ax_wide, ax_long) = plt.subplots(ncols=2)
        rugplot(data=wide_df, ax=ax_wide)
        for col in wide_df:
            rugplot(data=wide_df, x=col, ax=ax_long)

        wide_raw = np.array(ax_wide.collections[0].get_segments())
        long_raw = np.concatenate([c.get_segments() for c in ax_long.collections])

        assert_array_equal(np.sort(wide_raw), np.sort(long_raw))

    def test_flat_vector(self, long_df):
        """A vector passed as ``data`` or as ``x`` draws the same rug."""
        f, ax = plt.subplots()
        column = long_df["x"]
        rugplot(data=column)
        rugplot(x=column)
        self.assert_rug_equal(*ax.collections)

    def test_datetime_data(self, long_df):
        """Datetime values are drawn at their matplotlib date numbers."""
        ax = rugplot(data=long_df["t"])
        segments = np.stack(ax.collections[0].get_segments())
        positions = segments[:, 0, 0]
        assert_array_equal(positions, mpl.dates.date2num(long_df["t"]))

    def test_empty_data(self):
        """Empty input adds no collections."""
        assert not rugplot(x=[]).collections

    def test_a_deprecation(self, flat_series):
        """The ``a`` parameter warns but matches ``x``."""
        f, ax = plt.subplots()

        with pytest.warns(UserWarning):
            rugplot(a=flat_series)
        rugplot(x=flat_series)

        old_rug, new_rug = ax.collections
        self.assert_rug_equal(old_rug, new_rug)

    @pytest.mark.parametrize("variable", ["x", "y"])
    def test_axis_deprecation(self, flat_series, variable):
        """The ``axis`` parameter warns but matches the named variable."""
        f, ax = plt.subplots()

        with pytest.warns(UserWarning):
            rugplot(flat_series, axis=variable)
        rugplot(**{variable: flat_series})

        old_rug, new_rug = ax.collections
        self.assert_rug_equal(old_rug, new_rug)

    def test_vertical_deprecation(self, flat_series):
        """The ``vertical`` parameter warns but matches ``y``."""
        f, ax = plt.subplots()

        with pytest.warns(UserWarning):
            rugplot(flat_series, vertical=True)
        rugplot(y=flat_series)

        old_rug, new_rug = ax.collections
        self.assert_rug_equal(old_rug, new_rug)

    def test_rug_data(self, flat_array):
        """Segments run from 0 to ``height`` at each data value."""
        height = .05
        ax = rugplot(x=flat_array, height=height)
        segments = np.stack(ax.collections[0].get_segments())

        n = flat_array.size
        starts_y = segments[:, 0, 1]
        ends_y = segments[:, 1, 1]
        ends_x = segments[:, 1, 0]
        assert_array_equal(starts_y, np.zeros(n))
        assert_array_equal(ends_y, np.full(n, height))
        assert_array_equal(ends_x, flat_array)

    def test_rug_colors(self, long_df):
        """With ``hue``, each segment takes the palette color of its level."""
        ax = rugplot(data=long_df, x="x", hue="a")

        order = categorical_order(long_df["a"])
        palette = color_palette()

        # Opaque RGBA rows, one per observation
        expected_colors = np.ones((len(long_df), 4))
        rgb = [palette[order.index(level)] for level in long_df["a"]]
        expected_colors[:, :3] = rgb

        assert_array_equal(ax.collections[0].get_color(), expected_colors)

    def test_expand_margins(self, flat_array):
        """Check the axes margins with and without ``expand_margins``."""
        f, ax = plt.subplots()
        before = ax.margins()
        rugplot(x=flat_array, expand_margins=False)
        after = ax.margins()
        assert before[0] == after[0]
        assert before[1] == after[1]

        f, ax = plt.subplots()
        before = ax.margins()
        height = .05
        rugplot(x=flat_array, height=height)
        after = ax.margins()
        assert before[0] == after[0]
        assert before[1] + height * 2 == pytest.approx(after[1])

    def test_multiple_rugs(self):
        """A second rug with ``expand_margins=False`` keeps the y limits."""
        values = np.linspace(start=0, stop=1, num=5)
        ax = rugplot(x=values)
        ylim_before = ax.get_ylim()

        rugplot(x=values, ax=ax, expand_margins=False)

        assert ylim_before == ax.get_ylim()

    def test_matplotlib_kwargs(self, flat_series):
        """Extra keyword arguments are applied to the rug collection."""
        style = dict(linewidth=2, alpha=.2)
        ax = rugplot(y=flat_series, **style)
        rug = ax.collections[0]
        assert np.all(rug.get_alpha() == style["alpha"])
        assert np.all(rug.get_linewidth() == style["linewidth"])

    def test_axis_labels(self, flat_series):
        """Only the data axis is labeled, using the series name."""
        ax = rugplot(x=flat_series)
        assert ax.get_xlabel() == flat_series.name
        assert not ax.get_ylabel()

    def test_log_scale(self, long_df):
        """Rug segments are the same on linear and log-scaled axes."""
        ax_linear, ax_log = plt.figure().subplots(2)

        ax_log.set_xscale("log")

        rugplot(data=long_df, x="z", ax=ax_linear)
        rugplot(data=long_df, x="z", ax=ax_log)

        linear_rug = np.stack(ax_linear.collections[0].get_segments())
        log_rug = np.stack(ax_log.collections[0].get_segments())

        assert_array_almost_equal(linear_rug, log_rug)


class TestKDEPlotUnivariate(SharedAxesLevelTests):

    func = staticmethod(kdeplot)

    def get_last_color(self, ax, fill=True):
        """Return the color of the newest fill (or line, if not ``fill``)."""
        if not fill:
            return ax.lines[-1].get_color()
        return ax.collections[-1].get_facecolor()

    @pytest.mark.parametrize("fill", [True, False])
    def test_color(self, long_df, fill):
        """Run the shared color checks, plus face color keywords for fills."""
        super().test_color(long_df, fill=fill)

        if not fill:
            return

        # Both spellings of the face color keyword should set the fill color
        for keyword, color in [("facecolor", "C3"), ("fc", "C4")]:
            ax = plt.figure().subplots()
            self.func(data=long_df, x="y", fill=True, ax=ax, **{keyword: color})
            assert_colors_equal(self.get_last_color(ax), color, check_alpha=False)

    @pytest.mark.parametrize(
        "variable", ["x", "y"],
    )
    def test_long_vectors(self, long_df, variable):
        """Name, Series, array, and list inputs all draw the same curve."""
        series = long_df[variable]
        specs = [variable, series, series.to_numpy(), series.to_list()]

        f, ax = plt.subplots()
        for spec in specs:
            kdeplot(data=long_df, **{variable: spec})

        # Every pair of lines must agree on both coordinates
        for getter in ("get_xdata", "get_ydata"):
            coords = [getattr(line, getter)() for line in ax.lines]
            for a, b in itertools.product(coords, repeat=2):
                assert_array_equal(a, b)

    def test_wide_vs_long_data(self, wide_df):
        """Wide-form input matches per-column calls, in reversed line order."""
        f, (ax_wide, ax_long) = plt.subplots(ncols=2)
        kdeplot(data=wide_df, ax=ax_wide, common_norm=False, common_grid=False)
        for col in wide_df:
            kdeplot(data=wide_df, x=col, ax=ax_long)

        for wide_line, long_line in zip(ax_wide.lines[::-1], ax_long.lines):
            assert_array_equal(wide_line.get_xydata(), long_line.get_xydata())

    def test_flat_vector(self, long_df):
        """A vector passed as ``data`` or as ``x`` draws the same curve."""
        f, ax = plt.subplots()
        column = long_df["x"]
        kdeplot(data=column)
        kdeplot(x=column)
        from_data = ax.lines[0].get_xydata()
        from_x = ax.lines[1].get_xydata()
        assert_array_equal(from_data, from_x)

    def test_empty_data(self):
        """Empty input adds no lines."""
        assert not kdeplot(x=[]).lines

    def test_singular_data(self):
        """Singular data warns and draws nothing; the warning can be disabled."""
        singular_inputs = [
            np.ones(10),
            [5],
            # https://github.com/mwaskom/seaborn/issues/2762
            [1929245168.06679] * 18,
        ]
        for values in singular_inputs:
            with pytest.warns(UserWarning):
                ax = kdeplot(x=values)
            assert not ax.lines

        # With warn_singular=False, a UserWarning would be raised as an error
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            ax = kdeplot(x=[5], warn_singular=False)
        assert not ax.lines

    def test_variable_assignment(self, long_df):
        """Assigning the data to ``y`` transposes the filled path."""
        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x", fill=True)
        kdeplot(data=long_df, y="x", fill=True)

        x_fill, y_fill = ax.collections[0], ax.collections[1]
        x_verts = x_fill.get_paths()[0].vertices
        # Swap the columns so both arrays are in (data, density) order
        y_verts = y_fill.get_paths()[0].vertices[:, [1, 0]]

        assert_array_equal(x_verts, y_verts)

    def test_vertical_deprecation(self, long_df):
        """``vertical=True`` warns but matches assigning the data to ``y``."""
        f, ax = plt.subplots()
        kdeplot(data=long_df, y="x")

        with pytest.warns(UserWarning):
            kdeplot(data=long_df, x="x", vertical=True)

        current = ax.lines[0].get_xydata()
        deprecated = ax.lines[1].get_xydata()
        assert_array_equal(current, deprecated)

    def test_bw_deprecation(self, long_df):
        """The ``bw`` parameter warns but matches ``bw_method``."""
        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x", bw_method="silverman")

        with pytest.warns(UserWarning):
            kdeplot(data=long_df, x="x", bw="silverman")

        current = ax.lines[0].get_xydata()
        deprecated = ax.lines[1].get_xydata()
        assert_array_equal(current, deprecated)

    def test_kernel_deprecation(self, long_df):
        """The ``kernel`` parameter warns and the curve matches the default."""
        f, ax = plt.subplots()
        kdeplot(data=long_df, x="x")

        with pytest.warns(UserWarning):
            kdeplot(data=long_df, x="x", kernel="epi")

        default = ax.lines[0].get_xydata()
        with_kernel = ax.lines[1].get_xydata()
        assert_array_equal(default, with_kernel)

    def test_shade_deprecation(self, long_df):
        """The ``shade`` parameter emits a FutureWarning but matches ``fill``."""
        f, ax = plt.subplots()
        with pytest.warns(FutureWarning):
            kdeplot(data=long_df, x="x", shade=True)
        kdeplot(data=long_df, x="x", fill=True)

        shaded, filled = ax.collections
        shaded_verts = shaded.get_paths()[0].vertices
        filled_verts = filled.get_paths()[0].vertices
        assert_array_equal(shaded_verts, filled_verts)

    @pytest.mark.parametrize("multiple", ["layer", "stack", "fill"])
    def test_hue_colors(self, long_df, multiple):
        """Lines and fills follow the palette, with fills at alpha .25."""
        ax = kdeplot(
            data=long_df, x="x", hue="a",
            multiple=multiple,
            fill=True, legend=False
        )

        # Note that hue order is reversed in the plot
        artists = zip(ax.lines[::-1], ax.collections[::-1], color_palette())
        for line, area, color in artists:
            assert_colors_equal(line.get_color(), color)
            assert_colors_equal(area.get_facecolor(), to_rgba(color, .25))

    def test_hue_stacking(self, long_df):
        """Stacked curves equal the cumulative sum of the layered curves."""
        f, (ax_layer, ax_stack) = plt.subplots(ncols=2)
        shared = dict(data=long_df, x="x", hue="a", legend=False)

        kdeplot(**shared, multiple="layer", common_grid=True, ax=ax_layer)
        kdeplot(**shared, multiple="stack", fill=False, ax=ax_stack)

        layered_densities = np.stack(
            [line.get_ydata() for line in ax_layer.lines]
        )
        stacked_densities = np.stack(
            [line.get_ydata() for line in ax_stack.lines]
        )

        expected = layered_densities.cumsum(axis=0)
        assert_array_equal(expected, stacked_densities)

    def test_hue_filling(self, long_df):
        """Filled curves equal the cumulative layered proportions."""
        f, (ax_layer, ax_fill) = plt.subplots(ncols=2)
        shared = dict(data=long_df, x="x", hue="a", legend=False)

        kdeplot(**shared, multiple="layer", common_grid=True, ax=ax_layer)
        kdeplot(**shared, multiple="fill", fill=False, ax=ax_fill)

        layered = np.stack([line.get_ydata() for line in ax_layer.lines])
        filled = np.stack([line.get_ydata() for line in ax_fill.lines])

        # Normalize each grid point by the total density across hue levels
        proportions = layered / layered.sum(axis=0)
        assert_array_almost_equal(proportions.cumsum(axis=0), filled)

    @pytest.mark.parametrize("multiple", ["stack", "fill"])
    def test_fill_default(self, long_df, multiple):
        """With ``fill=None``, stacked and filled plots still draw fills."""
        plot_kws = dict(data=long_df, x="x", hue="a", multiple=multiple, fill=None)
        ax = kdeplot(**plot_kws)

        assert len(ax.collections) > 0

    @pytest.mark.parametrize("multiple", ["layer", "stack", "fill"])
    def test_fill_nondefault(self, long_df, multiple):
        """An explicit ``fill`` value decides whether collections are drawn."""
        f, (ax_lines, ax_fills) = plt.subplots(ncols=2)

        for ax, fill in [(ax_lines, False), (ax_fills, True)]:
            kdeplot(
                data=long_df, x="x", hue="a",
                multiple=multiple, fill=fill, ax=ax,
            )

        assert len(ax_lines.collections) == 0
        assert len(ax_fills.collections) > 0

    def test_color_cycle_interaction(self, flat_series):
        """Check how default and explicit colors interact with the cycle."""
        custom = (.2, 1, .6)

        # Consecutive default calls take consecutive cycle colors
        fig, ax = plt.subplots()
        kdeplot(flat_series)
        kdeplot(flat_series)
        assert_colors_equal(ax.lines[0].get_color(), "C0")
        assert_colors_equal(ax.lines[1].get_color(), "C1")
        plt.close(fig)

        # After an explicit color, the next default call still gets "C0"
        fig, ax = plt.subplots()
        kdeplot(flat_series, color=custom)
        kdeplot(flat_series)
        assert_colors_equal(ax.lines[0].get_color(), custom)
        assert_colors_equal(ax.lines[1].get_color(), "C0")
        plt.close(fig)

        # Filled curves use consecutive cycle colors too, at alpha .25
        fig, ax = plt.subplots()
        kdeplot(flat_series, fill=True)
        kdeplot(flat_series, fill=True)
        first_face = ax.collections[0].get_facecolor()
        second_face = ax.collections[1].get_facecolor()
        assert_colors_equal(first_face, to_rgba("C0", .25))
        assert_colors_equal(second_face, to_rgba("C1", .25))
        plt.close(fig)

    @pytest.mark.parametrize("fill", [True, False])
    def test_artist_color(self, long_df, fill):
        """Check the artist color with the default and an explicit alpha."""
        color = (.2, 1, .6)
        alpha = .5

        f, ax = plt.subplots()

        def newest_color():
            # Fills are collections; unfilled curves are lines
            if fill:
                return ax.collections[-1].get_facecolor().squeeze()
            return ax.lines[-1].get_color()

        kdeplot(long_df["x"], fill=fill, color=color)
        default_alpha = .25 if fill else 1
        assert_colors_equal(newest_color(), to_rgba(color, default_alpha))

        kdeplot(long_df["x"], fill=fill, color=color, alpha=alpha)
        assert_colors_equal(newest_color(), to_rgba(color, alpha))

    def test_datetime_scale(self, long_df):
        """Filled and unfilled plots of datetime data share x limits."""
        f, (ax_fill, ax_line) = plt.subplots(2)
        times = long_df["t"]
        kdeplot(x=times, fill=True, ax=ax_fill)
        kdeplot(x=times, fill=False, ax=ax_line)
        assert ax_fill.get_xlim() == ax_line.get_xlim()

    def test_multiple_argument_check(self, long_df):
        """An unknown ``multiple`` value raises a ValueError."""
        bad_kws = dict(data=long_df, x="x", hue="a", multiple="bad_input")
        with pytest.raises(ValueError, match="`multiple` must be"):
            kdeplot(**bad_kws)

    def test_cut(self, rng):
        """``cut=0`` stops at the data extremes; ``cut=2`` extends past them."""
        x = rng.normal(0, 3, 1000)

        f, ax = plt.subplots()

        kdeplot(x=x, cut=0, legend=False)
        tight = ax.lines[0].get_xdata()
        assert tight.min() == x.min()
        assert tight.max() == x.max()

        kdeplot(x=x, cut=2, legend=False)
        wide = ax.lines[1].get_xdata()
        assert wide.min() < tight.min()
        assert wide.max() > tight.max()

        # The grid has the same number of points either way
        assert len(tight) == len(wide)

    def test_clip(self, rng):
        """The evaluation grid should stay inside the `clip` bounds."""
        sample = rng.normal(0, 3, 1000)
        lower, upper = -1, 1

        ax = kdeplot(x=sample, clip=(lower, upper))
        support = ax.lines[0].get_xdata()

        assert support.min() >= lower
        assert support.max() <= upper

    def test_line_is_density(self, long_df):
        """The plotted curve should integrate to one."""
        ax = kdeplot(data=long_df, x="x", cut=5)
        curve = ax.lines[0]
        support, density = curve.get_xydata().T
        area = integrate(density, support)
        assert area == pytest.approx(1)

    @pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
    def test_cumulative(self, long_df):
        """A cumulative curve should start near zero and end near one."""
        ax = kdeplot(data=long_df, x="x", cut=5, cumulative=True)
        heights = ax.lines[0].get_ydata()
        first, last = heights[0], heights[-1]
        assert first == pytest.approx(0)
        assert last == pytest.approx(1)

    @pytest.mark.skipif(not _no_scipy, reason="Test requires scipy's absence")
    def test_cumulative_requires_scipy(self, long_df):
        """Without scipy, requesting a cumulative KDE raises RuntimeError."""
        plot_kws = dict(data=long_df, x="x", cut=5, cumulative=True)
        with pytest.raises(RuntimeError):
            kdeplot(**plot_kws)

    def test_common_norm(self, long_df):
        """Compare the curve areas with and without `common_norm`."""
        f, (ax_common, ax_separate) = plt.subplots(ncols=2)

        shared_kws = dict(data=long_df, x="x", hue="c", cut=10)
        kdeplot(**shared_kws, common_norm=True, ax=ax_common)
        kdeplot(**shared_kws, common_norm=False, ax=ax_separate)

        def area_under(line):
            xs, ys = line.get_xydata().T
            return integrate(ys, xs)

        # Common normalization: the areas of all curves sum to one.
        total_area = sum(area_under(line) for line in ax_common.lines)
        assert total_area == pytest.approx(1)

        # Separate normalization: every curve integrates to one on its own.
        for line in ax_separate.lines:
            assert area_under(line) == pytest.approx(1)

    def test_common_grid(self, long_df):
        """Check the support of each curve with and without `common_grid`."""
        levels = "a", "b", "c"
        f, (ax_own, ax_shared) = plt.subplots(ncols=2)

        shared_kws = dict(data=long_df, x="x", hue="a", hue_order=levels, cut=0)
        kdeplot(**shared_kws, common_grid=False, ax=ax_own)
        kdeplot(**shared_kws, common_grid=True, ax=ax_shared)

        # The lines are walked in reverse to pair them up with `levels`.
        for level, line in zip(levels, ax_own.lines[::-1]):
            level_x = long_df.loc[long_df["a"] == level, "x"]
            support = line.get_xdata()
            assert support.min() == level_x.min()
            assert support.max() == level_x.max()

        # With a common grid every curve spans the range of the full column.
        all_x = long_df["x"]
        for line in ax_shared.lines:
            support = line.get_xdata().T
            assert support.min() == all_x.min()
            assert support.max() == all_x.max()

    def test_bw_method(self, long_df):
        """Larger `bw_method` values should give less variable curves."""
        f, ax = plt.subplots()
        for bandwidth in (0.2, 1.0, 3.0):
            kdeplot(data=long_df, x="x", bw_method=bandwidth, legend=False)

        # Mean absolute step between neighboring y values, one per curve.
        roughness = [
            np.abs(np.diff(line.get_ydata())).mean() for line in ax.lines
        ]
        narrow, medium, wide = roughness

        assert narrow > medium
        assert medium > wide

    def test_bw_adjust(self, long_df):
        """Larger `bw_adjust` values should give less variable curves."""
        f, ax = plt.subplots()
        for factor in (0.2, 1.0, 3.0):
            kdeplot(data=long_df, x="x", bw_adjust=factor, legend=False)

        def mean_abs_step(line):
            # Average absolute change between consecutive y values.
            return np.abs(np.diff(line.get_ydata())).mean()

        first, second, third = ax.lines

        assert mean_abs_step(first) > mean_abs_step(second)
        assert mean_abs_step(second) > mean_abs_step(third)

    def test_log_scale_implicit(self, rng):
        """A log scale already set on the axes should shape the grid."""
        sample = rng.lognormal(0, 1, 100)

        f, (ax1, _) = plt.subplots(ncols=2)
        ax1.set_xscale("log")

        for _ in range(2):
            kdeplot(x=sample, ax=ax1)

        grid = ax1.lines[0].get_xdata()
        assert (grid > 0).all()
        # The spacing grows on the linear scale ...
        assert (np.diff(grid, 2) > 0).all()
        # ... and is even after taking the log.
        assert np.allclose(np.diff(np.log(grid), 2), 0)

        # The same data on a log-scaled y axis gives the transposed curve.
        f, ax = plt.subplots()
        ax.set_yscale("log")
        kdeplot(y=sample, ax=ax)
        assert_array_equal(ax.lines[0].get_xdata(), ax1.lines[0].get_ydata())

    def test_log_scale_explicit(self, rng):
        """Different ways of requesting a log scale should give equal curves."""
        sample = rng.lognormal(0, 1, 100)

        f, (ax1, ax2, ax3) = plt.subplots(ncols=3)

        ax1.set_xscale("log")
        kdeplot(x=sample, ax=ax1)
        kdeplot(x=sample, log_scale=True, ax=ax2)
        kdeplot(x=sample, log_scale=10, ax=ax3)

        for ax in f.axes:
            assert ax.get_xscale() == "log"

        # Compare the supports first, then the densities, pairwise.
        for getter in ("get_xdata", "get_ydata"):
            arrays = [getattr(ax.lines[0], getter)() for ax in f.axes]
            for left, right in itertools.product(arrays, repeat=2):
                assert_array_equal(left, right)

        f, ax = plt.subplots()
        kdeplot(y=sample, log_scale=True, ax=ax)
        assert ax.get_yscale() == "log"

    def test_log_scale_with_hue(self, rng):
        """With `common_grid`, both curves should share one x grid."""
        narrow = rng.lognormal(0, 1, 50)
        wide = rng.lognormal(0, 2, 100)
        ax = kdeplot(data=(narrow, wide), log_scale=True, common_grid=True)
        first, second = ax.lines[0], ax.lines[1]
        assert_array_equal(first.get_xdata(), second.get_xdata())

    def test_log_scale_normalization(self, rng):
        """The curve should integrate to one over the log10 of its support."""
        sample = rng.lognormal(0, 1, 100)
        ax = kdeplot(x=sample, log_scale=True, cut=10)
        support, density = ax.lines[0].get_xydata().T
        area = integrate(density, np.log10(support))
        assert area == pytest.approx(1)

    def test_weights(self):
        """A point with twice the weight should have twice the density."""
        points = [1, 2]
        point_weights = [2, 1]

        ax = kdeplot(x=points, weights=point_weights, bw_method=.1)
        support, density = ax.lines[0].get_xydata().T

        def density_near(value):
            # Density at the grid point closest to `value`.
            return density[np.abs(support - value).argmin()]

        assert density_near(1) == pytest.approx(2 * density_near(2))

    def test_weight_norm(self, rng):
        """Under `common_norm`, curve areas should follow the group weights."""
        base = rng.normal(0, 1, 50)
        x = np.concatenate([base, base])
        w = np.repeat([1, 2], 50)
        ax = kdeplot(x=x, weights=w, hue=w, common_norm=True)

        def area_under(line):
            xs, ys = line.get_xydata().T
            return integrate(ys, xs)

        # Recall that artists are added in reverse of hue order
        area_first = area_under(ax.lines[0])
        area_second = area_under(ax.lines[1])

        assert area_first == pytest.approx(2 * area_second)

    def test_sticky_edges(self, long_df):
        """Check the y sticky edges set on filled KDE artists."""
        f, (ax_single, ax_filled) = plt.subplots(ncols=2)

        kdeplot(data=long_df, x="x", fill=True, ax=ax_single)
        single_edges = ax_single.collections[0].sticky_edges.y[:]
        assert single_edges == [0, np.inf]

        fill_kws = dict(data=long_df, x="x", hue="a", multiple="fill", fill=True)
        kdeplot(**fill_kws, ax=ax_filled)
        filled_edges = ax_filled.collections[0].sticky_edges.y[:]
        assert filled_edges == [0, 1]

    def test_line_kws(self, flat_array):
        """Line keyword arguments should be applied to the single drawn line."""
        width = 3
        rgb = (.2, .5, .8)
        ax = kdeplot(x=flat_array, linewidth=width, color=rgb)
        (curve,) = ax.lines
        assert curve.get_linewidth() == width
        assert_colors_equal(curve.get_color(), rgb)

    def test_input_checking(self, long_df):
        """A categorical x variable should be rejected with a TypeError."""
        with pytest.raises(TypeError, match="The x variable is categorical,"):
            kdeplot(data=long_df, x="a")

    def test_axis_labels(self, long_df):
        """The data axis is labeled by the variable, the other "Density"."""
        f, (ax_x, ax_y) = plt.subplots(ncols=2)

        kdeplot(data=long_df, x="x", ax=ax_x)
        assert (ax_x.get_xlabel(), ax_x.get_ylabel()) == ("x", "Density")

        kdeplot(data=long_df, y="y", ax=ax_y)
        assert (ax_y.get_xlabel(), ax_y.get_ylabel()) == ("Density", "y")

    def test_legend(self, long_df):
        """Check the legend title, labels, colors, and that it can be off."""
        ax = kdeplot(data=long_df, x="x", hue="a")
        legend = ax.legend_

        assert legend.get_title().get_text() == "a"

        expected_levels = categorical_order(long_df["a"])
        for text, level in zip(legend.get_texts(), expected_levels):
            assert text.get_text() == level

        handles = legend.findobj(mpl.lines.Line2D)
        if _version_predates(mpl, "3.5.0b0"):
            # https://github.com/matplotlib/matplotlib/pull/20699
            handles = handles[::2]
        for handle, expected_color in zip(handles, color_palette()):
            assert_colors_equal(handle.get_color(), expected_color)

        ax.clear()

        kdeplot(data=long_df, x="x", hue="a", legend=False)

        assert ax.legend_ is None

    def test_replaced_kws(self, long_df):
        """Passing the removed `data2` parameter should raise a TypeError."""
        removed_msg = r"`data2` has been removed"
        with pytest.raises(TypeError, match=removed_msg):
            kdeplot(data=long_df, x="x", data2="y")


class TestKDEPlotBivariate:
    """Tests for `kdeplot` when both `x` and `y` are assigned."""

    def test_long_vectors(self, long_df):
        """Column names, Series, arrays, and lists should plot equivalently."""
        ax1 = kdeplot(data=long_df, x="x", y="y")

        x = long_df["x"]
        x_values = [x, x.to_numpy(), x.to_list()]

        y = long_df["y"]
        y_values = [y, y.to_numpy(), y.to_list()]

        for x, y in zip(x_values, y_values):
            f, ax2 = plt.subplots()
            kdeplot(x=x, y=y, ax=ax2)

            for c1, c2 in zip(ax1.collections, ax2.collections):
                assert_array_equal(c1.get_offsets(), c2.get_offsets())

    def test_singular_data(self):
        """Singular input should warn (unless disabled) and add no lines."""
        # NOTE(review): each check below inspects `ax.lines`; confirm that
        # this is the intended container for the bivariate case.
        with pytest.warns(UserWarning):
            ax = dist.kdeplot(x=np.ones(10), y=np.arange(10))
        assert not ax.lines

        with pytest.warns(UserWarning):
            ax = dist.kdeplot(x=[5], y=[6])
        assert not ax.lines

        with pytest.warns(UserWarning):
            ax = kdeplot(x=[1929245168.06679] * 18, y=np.arange(18))
        assert not ax.lines

        # With warn_singular=False no UserWarning may be emitted; the filter
        # turns any such warning into an error.
        with warnings.catch_warnings():
            warnings.simplefilter("error", UserWarning)
            ax = kdeplot(x=[5], y=[7], warn_singular=False)
        assert not ax.lines

    def test_fill_artists(self, long_df):
        """Check the artist type added for each `fill` setting."""
        for fill in [True, False]:
            f, ax = plt.subplots()
            kdeplot(data=long_df, x="x", y="y", hue="c", fill=fill)
            for c in ax.collections:
                # The expected artist class depends on the matplotlib version.
                if not _version_predates(mpl, "3.8.0rc1"):
                    assert isinstance(c, mpl.contour.QuadContourSet)
                elif fill or not _version_predates(mpl, "3.5.0b0"):
                    assert isinstance(c, mpl.collections.PathCollection)
                else:
                    assert isinstance(c, mpl.collections.LineCollection)

    def test_common_norm(self, rng):
        """Separate normalization should produce more contour segments."""
        hue = np.repeat(["a", "a", "a", "b"], 40)
        x, y = rng.multivariate_normal([0, 0], [(.2, .5), (.5, 2)], len(hue)).T
        x[hue == "a"] -= 2
        x[hue == "b"] += 2

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, hue=hue, common_norm=True, ax=ax1)
        kdeplot(x=x, y=y, hue=hue, common_norm=False, ax=ax2)

        n_seg_1 = sum(len(get_contour_coords(c, True)) for c in ax1.collections)
        n_seg_2 = sum(len(get_contour_coords(c, True)) for c in ax2.collections)
        assert n_seg_2 > n_seg_1

    def test_log_scale(self, rng):
        """Check axis scales and contour coordinates under `log_scale`."""
        x = rng.lognormal(0, 1, 100)
        y = rng.uniform(0, 1, 100)

        levels = .2, .5, 1

        f, ax = plt.subplots()
        kdeplot(x=x, y=y, log_scale=True, levels=levels, ax=ax)
        assert ax.get_xscale() == "log"
        assert ax.get_yscale() == "log"

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, log_scale=(10, False), levels=levels, ax=ax1)
        assert ax1.get_xscale() == "log"
        assert ax1.get_yscale() == "linear"

        # Build the reference contours by hand: estimate the density on
        # log10(x), then draw it against the back-transformed x grid.
        p = _DistributionPlotter()
        kde = KDE()
        density, (xx, yy) = kde(np.log10(x), y)
        levels = p._quantile_to_level(density, levels)
        ax2.contour(10 ** xx, yy, density, levels=levels)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            assert len(get_contour_coords(c1)) == len(get_contour_coords(c2))
            for arr1, arr2 in zip(get_contour_coords(c1), get_contour_coords(c2)):
                assert_array_equal(arr1, arr2)

    def test_bandwidth(self, rng):
        """A larger `bw_adjust` should push the contours further out in x."""
        n = 100
        x, y = rng.multivariate_normal([0, 0], [(.2, .5), (.5, 2)], n).T

        f, (ax1, ax2) = plt.subplots(ncols=2)

        kdeplot(x=x, y=y, ax=ax1)
        kdeplot(x=x, y=y, bw_adjust=2, ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            seg1, seg2 = get_contour_coords(c1), get_contour_coords(c2)
            # Both sets are indexed below, so both must be non-empty.
            if seg1 and seg2:
                x1 = seg1[0][:, 0]
                x2 = seg2[0][:, 0]
                assert np.abs(x2).max() > np.abs(x1).max()

    def test_weights(self, rng):
        """Passing `weights` should change the contour coordinates."""
        n = 100
        x, y = rng.multivariate_normal([1, 3], [(.2, .5), (.5, 2)], n).T
        hue = np.repeat([0, 1], n // 2)
        weights = rng.uniform(0, 1, n)

        f, (ax1, ax2) = plt.subplots(ncols=2)
        kdeplot(x=x, y=y, hue=hue, ax=ax1)
        kdeplot(x=x, y=y, hue=hue, weights=weights, ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            if get_contour_coords(c1) and get_contour_coords(c2):
                seg1 = np.concatenate(get_contour_coords(c1), axis=0)
                seg2 = np.concatenate(get_contour_coords(c2), axis=0)
                assert not np.array_equal(seg1, seg2)

    def test_hue_ignores_cmap(self, long_df):
        """With `hue`, a `cmap` should trigger a warning and be ignored."""
        with pytest.warns(UserWarning, match="cmap parameter ignored"):
            ax = kdeplot(data=long_df, x="x", y="y", hue="c", cmap="viridis")

        assert_colors_equal(get_contour_color(ax.collections[0]), "C0")

    def test_contour_line_colors(self, long_df):
        """Contour lines should use the requested `color`."""
        color = (.2, .9, .8, 1)
        ax = kdeplot(data=long_df, x="x", y="y", color=color)

        for c in ax.collections:
            assert_colors_equal(get_contour_color(c), color)

    def test_contour_line_cmap(self, long_df):
        """Contour line colors should all come from the given colormap."""
        color_list = color_palette("Blues", 12)
        cmap = mpl.colors.ListedColormap(color_list)
        ax = kdeplot(data=long_df, x="x", y="y", cmap=cmap)
        for c in ax.collections:
            for color in get_contour_color(c):
                assert to_rgb(color) in color_list

    def test_contour_fill_colors(self, long_df):
        """Filled contours should draw their colors from a light palette."""
        n = 6
        color = (.2, .9, .8, 1)
        ax = kdeplot(
            data=long_df, x="x", y="y", fill=True, color=color, levels=n,
        )

        # Lookup table of the colormap derived from the requested color.
        cmap = light_palette(color, reverse=True, as_cmap=True)
        lut = cmap(np.linspace(0, 1, 256))
        for c in ax.collections:
            for facecolor in c.get_facecolor():
                assert facecolor in lut

    def test_colorbar(self, long_df):
        """Requesting `cbar` should add a second axes to the figure."""
        ax = kdeplot(data=long_df, x="x", y="y", fill=True, cbar=True)
        assert len(ax.figure.axes) == 2

    def test_levels_and_thresh(self, long_df):
        """Check how `levels` and `thresh` interact and are validated."""
        f, (ax1, ax2) = plt.subplots(ncols=2)

        n = 8
        thresh = .1
        plot_kws = dict(data=long_df, x="x", y="y")
        # An integer level count plus a threshold should match the
        # explicit level vector spanning from the threshold to one.
        kdeplot(**plot_kws, levels=n, thresh=thresh, ax=ax1)
        kdeplot(**plot_kws, levels=np.linspace(thresh, 1, n), ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            assert len(get_contour_coords(c1)) == len(get_contour_coords(c2))
            for arr1, arr2 in zip(get_contour_coords(c1), get_contour_coords(c2)):
                assert_array_equal(arr1, arr2)

        with pytest.raises(ValueError):
            kdeplot(**plot_kws, levels=[0, 1, 2])

        ax1.clear()
        ax2.clear()

        # thresh=None and thresh=0 should draw the same contours.
        kdeplot(**plot_kws, levels=n, thresh=None, ax=ax1)
        kdeplot(**plot_kws, levels=n, thresh=0, ax=ax2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            assert len(get_contour_coords(c1)) == len(get_contour_coords(c2))
            for arr1, arr2 in zip(get_contour_coords(c1), get_contour_coords(c2)):
                assert_array_equal(arr1, arr2)

        for c1, c2 in zip(ax1.collections, ax2.collections):
            assert_array_equal(c1.get_facecolors(), c2.get_facecolors())

    def test_quantile_to_level(self, rng):
        """Each level should bound the requested proportion of total mass."""
        x = rng.uniform(0, 1, 100000)
        isoprop = np.linspace(.1, 1, 6)

        levels = _DistributionPlotter()._quantile_to_level(x, isoprop)
        for h, p in zip(levels, isoprop):
            assert (x[x <= h].sum() / x.sum()) == pytest.approx(p, abs=1e-4)

    def test_input_checking(self, long_df):
        """A categorical x variable should be rejected with a TypeError."""
        with pytest.raises(TypeError, match="The x variable is categorical,"):
            kdeplot(data=long_df, x="a", y="y")


class TestHistPlotUnivariate(SharedAxesLevelTests):

    func = staticmethod(histplot)

    def get_last_color(self, ax, element="bars", fill=True):
        """Return the color of the most recently added artist on `ax`.

        Bars are looked up in ``ax.patches``; other elements come from
        ``ax.collections`` when filled and from ``ax.lines`` otherwise.
        """
        if element == "bars":
            newest_bar = ax.patches[-1]
            if fill:
                return newest_bar.get_facecolor()
            return newest_bar.get_edgecolor()

        if not fill:
            return ax.lines[-1].get_color()

        newest_poly = ax.collections[-1]
        face = newest_poly.get_facecolor()
        edge = newest_poly.get_edgecolor()
        # Face and edge are required to agree, ignoring alpha.
        assert_colors_equal(face, edge, check_alpha=False)
        return face

    @pytest.mark.parametrize(
        "element,fill",
        itertools.product(["bars", "step", "poly"], [True, False]),
    )
    def test_color(self, long_df, element, fill):
        """Run the inherited color test for each element/fill combination."""
        plot_kws = dict(element=element, fill=fill)
        super().test_color(long_df, **plot_kws)

    @pytest.mark.parametrize(
        "variable", ["x", "y"],
    )
    def test_long_vectors(self, long_df, variable):
        """Equivalent ways of passing one variable should draw identical bars.

        The variable is given as a column name, a Series, a numpy array, and a
        list; each form is drawn on its own axes and the bars are compared
        across every pair of axes.
        """
        vector = long_df[variable]
        vectors = [
            variable, vector, vector.to_numpy(), vector.to_list(),
        ]

        # Create one axes per input form. With fewer axes than forms, `zip`
        # would silently drop the trailing forms and leave them untested.
        f, axs = plt.subplots(len(vectors))
        for vector, ax in zip(vectors, axs):
            histplot(data=long_df, ax=ax, **{variable: vector})

        bars = [ax.patches for ax in axs]
        for a_bars, b_bars in itertools.product(bars, bars):
            for a, b in zip(a_bars, b_bars):
                assert_array_equal(a.get_height(), b.get_height())
                assert_array_equal(a.get_xy(), b.get_xy())

    def test_wide_vs_long_data(self, wide_df):
        """Wide-form input should match plotting each column separately."""
        f, (ax_wide, ax_long) = plt.subplots(2)

        histplot(data=wide_df, ax=ax_wide, common_bins=False)

        # The columns are drawn one by one in reverse order for comparison.
        for column in wide_df.columns[::-1]:
            histplot(data=wide_df, x=column, ax=ax_long)

        for wide_bar, long_bar in zip(ax_wide.patches, ax_long.patches):
            assert wide_bar.get_height() == long_bar.get_height()
            assert wide_bar.get_xy() == long_bar.get_xy()

    def test_flat_vector(self, long_df):
        """A Series passed as `data` should match naming the column as `x`."""
        f, (ax_flat, ax_named) = plt.subplots(2)

        histplot(data=long_df["x"], ax=ax_flat)
        histplot(data=long_df, x="x", ax=ax_named)

        for flat_bar, named_bar in zip(ax_flat.patches, ax_named.patches):
            assert flat_bar.get_height() == named_bar.get_height()
            assert flat_bar.get_xy() == named_bar.get_xy()

    def test_empty_data(self):
        """Plotting an empty vector should not draw any bars."""
        ax = histplot(x=[])
        assert len(ax.patches) == 0

    def test_variable_assignment(self, long_df):
        """Bar heights for an x histogram should equal bar widths for y."""
        f, (ax_vertical, ax_horizontal) = plt.subplots(2)

        histplot(data=long_df, x="x", ax=ax_vertical)
        histplot(data=long_df, y="x", ax=ax_horizontal)

        for v_bar, h_bar in zip(ax_vertical.patches, ax_horizontal.patches):
            assert v_bar.get_height() == h_bar.get_width()

    @pytest.mark.parametrize("element", ["bars", "step", "poly"])
    @pytest.mark.parametrize("multiple", ["layer", "dodge", "stack", "fill"])
    def test_hue_fill_colors(self, long_df, multiple, element):
        """Filled artists should use palette colors with the expected alpha."""
        ax = histplot(
            data=long_df, x="x", hue="a",
            multiple=multiple, bins=1,
            fill=True, element=element, legend=False,
        )

        # The expected alpha depends on `multiple` and, for layers, `element`.
        if multiple != "layer":
            expected_alpha = .75
        elif element == "bars":
            expected_alpha = .5
        else:
            expected_alpha = .25

        expected_colors = [to_rgba(c, expected_alpha) for c in color_palette()]

        # Bars first, then polygons; each is compared in reverse order.
        for artists in (ax.patches, ax.collections):
            for artist, expected in zip(artists[::-1], expected_colors):
                assert_colors_equal(artist.get_facecolor(), expected)

    def test_hue_stack(self, long_df):
        """Stacked bars keep their layer heights and rest on the bars below."""
        n_bins = 10
        shared_kws = dict(
            data=long_df, x="x", hue="a", bins=n_bins, element="bars"
        )

        f, (ax_layer, ax_stack) = plt.subplots(2)
        histplot(**shared_kws, multiple="layer", ax=ax_layer)
        histplot(**shared_kws, multiple="stack", ax=ax_stack)

        def heights_of(axes):
            # Reshape to rows of `n_bins` bars each.
            heights = [bar.get_height() for bar in axes.patches]
            return np.reshape(heights, (-1, n_bins))

        layer_heights = heights_of(ax_layer)
        stack_heights = heights_of(ax_stack)
        assert_array_equal(layer_heights, stack_heights)

        stack_origins = np.reshape(
            [bar.get_xy() for bar in ax_stack.patches], (-1, n_bins, 2)
        )
        bar_tops = stack_origins[..., 1] + stack_heights
        assert_array_equal(bar_tops, stack_heights.cumsum(axis=0))

    def test_hue_fill(self, long_df):
        """Filled bars should be the layer heights normalized within each bin."""
        n_bins = 10
        shared_kws = dict(
            data=long_df, x="x", hue="a", bins=n_bins, element="bars"
        )

        f, (ax_layer, ax_fill) = plt.subplots(2)
        histplot(**shared_kws, multiple="layer", ax=ax_layer)
        histplot(**shared_kws, multiple="fill", ax=ax_fill)

        def heights_of(axes):
            # Reshape to rows of `n_bins` bars each.
            heights = [bar.get_height() for bar in axes.patches]
            return np.reshape(heights, (-1, n_bins))

        layer_heights = heights_of(ax_layer)
        fill_heights = heights_of(ax_fill)
        assert_array_almost_equal(
            layer_heights / layer_heights.sum(axis=0), fill_heights
        )

        fill_origins = np.reshape(
            [bar.get_xy() for bar in ax_fill.patches], (-1, n_bins, 2)
        )
        fill_bottoms = fill_origins[..., 1]
        assert_array_almost_equal(
            (fill_bottoms + fill_heights) / fill_heights.sum(axis=0),
            fill_heights.cumsum(axis=0),
        )

    def test_hue_dodge(self, long_df):
        """Dodged bars keep their layer heights but change x position."""
        binwidth = 2
        shared_kws = dict(
            data=long_df, x="x", hue="c", binwidth=binwidth, element="bars"
        )

        f, (ax_layer, ax_dodge) = plt.subplots(2)
        histplot(**shared_kws, multiple="layer", ax=ax_layer)
        histplot(**shared_kws, multiple="dodge", ax=ax_dodge)

        assert_array_equal(
            [bar.get_height() for bar in ax_layer.patches],
            [bar.get_height() for bar in ax_dodge.patches],
        )

        # Split the bar positions into two rows.
        layer_xs, dodge_xs = (
            np.reshape([bar.get_x() for bar in axes.patches], (2, -1))
            for axes in (ax_layer, ax_dodge)
        )
        assert_array_almost_equal(layer_xs[1], dodge_xs[1])
        assert_array_almost_equal(layer_xs[0], dodge_xs[0] - binwidth / 2)

    def test_hue_as_numpy_dodged(self, long_df):
        """Dodging should work when `hue` is passed as a numpy array."""
        # https://github.com/mwaskom/seaborn/issues/2452
        hue_values = long_df["a"].to_numpy()
        ax = histplot(long_df, x="y", hue=hue_values, multiple="dodge", bins=1)
        # Note hue order reversal
        first_bar, second_bar = ax.patches[0], ax.patches[1]
        assert second_bar.get_x() < first_bar.get_x()

    def test_multiple_input_check(self, flat_series):
        """An unrecognized `multiple` value should raise a ValueError."""
        expected_msg = "`multiple` must be"
        with pytest.raises(ValueError, match=expected_msg):
            histplot(flat_series, multiple="invalid")

    def test_element_input_check(self, flat_series):
        """An unrecognized `element` value should raise a ValueError."""
        expected_msg = "`element` must be"
        with pytest.raises(ValueError, match=expected_msg):
            histplot(flat_series, element="invalid")

    def test_count_stat(self, flat_series):
        """With stat="count", the bar heights should sum to the sample size."""
        ax = histplot(flat_series, stat="count")
        total = sum(bar.get_height() for bar in ax.patches)
        assert total == len(flat_series)

    def test_density_stat(self, flat_series):
        """With stat="density", the total bar area should be one."""
        ax = histplot(flat_series, stat="density")
        bars = ax.patches
        total_area = np.multiply(
            [bar.get_height() for bar in bars],
            [bar.get_width() for bar in bars],
        ).sum()
        assert total_area == pytest.approx(1)

    def test_density_stat_common_norm(self, long_df):
        """With common normalization, all bars together should have area one."""
        plot_kws = dict(
            data=long_df, x="x", hue="a",
            stat="density", common_norm=True, element="bars",
        )
        ax = histplot(**plot_kws)
        bars = ax.patches
        total_area = np.multiply(
            [bar.get_height() for bar in bars],
            [bar.get_width() for bar in bars],
        ).sum()
        assert total_area == pytest.approx(1)

    def test_density_stat_unique_norm(self, long_df):
        """Without common normalization, a hue group's bars have area one."""
        n_bins = 10
        ax = histplot(
            data=long_df, x="x", hue="a",
            stat="density", bins=n_bins, common_norm=False, element="bars",
        )

        # Check the first and the last group of `n_bins` bars.
        first_group, last_group = ax.patches[:n_bins], ax.patches[-n_bins:]

        for group in (first_group, last_group):
            areas = np.multiply(
                [bar.get_height() for bar in group],
                [bar.get_width() for bar in group],
            )
            assert areas.sum() == pytest.approx(1)

    @pytest.fixture(params=["probability", "proportion"])
    def height_norm_arg(self, request):
        """Provide each parametrized value of this fixture in turn."""
        current_value = request.param
        return current_value

    def test_probability_stat(self, flat_series, height_norm_arg):
        """The bar heights should sum to one for the given `stat`."""
        ax = histplot(flat_series, stat=height_norm_arg)
        total_height = sum(bar.get_height() for bar in ax.patches)
        assert total_height == pytest.approx(1)

    def test_probability_stat_common_norm(self, long_df, height_norm_arg):
        """With common normalization, all bar heights together sum to one."""
        plot_kws = dict(
            data=long_df, x="x", hue="a",
            stat=height_norm_arg, common_norm=True, element="bars",
        )
        ax = histplot(**plot_kws)
        total_height = sum(bar.get_height() for bar in ax.patches)
        assert total_height == pytest.approx(1)

    def test_probability_stat_unique_norm(self, long_df, height_norm_arg):
        """Without common normalization, a hue group's heights sum to one."""
        n_bins = 10
        ax = histplot(
            data=long_df, x="x", hue="a",
            stat=height_norm_arg, bins=n_bins, common_norm=False, element="bars",
        )

        # Check the first and the last group of `n_bins` bars.
        first_group, last_group = ax.patches[:n_bins], ax.patches[-n_bins:]

        for group in (first_group, last_group):
            group_total = sum(bar.get_height() for bar in group)
            assert group_total == pytest.approx(1)

    def test_percent_stat(self, flat_series):
        """With stat="percent", the bar heights should sum to 100."""
        ax = histplot(flat_series, stat="percent")
        bar_heights = [b.get_height() for b in ax.patches]
        # The heights are floats, so compare with a tolerance rather than
        # requiring the accumulated sum to be exactly 100.
        assert sum(bar_heights) == pytest.approx(100)

    def test_common_bins(self, long_df):
        """With common bins, different hue groups should share bar positions."""
        n_bins = 10
        ax = histplot(
            long_df, x="x", hue="a", common_bins=True, bins=n_bins, element="bars",
        )

        first_group, last_group = ax.patches[:n_bins], ax.patches[-n_bins:]
        first_origins = [bar.get_xy() for bar in first_group]
        last_origins = [bar.get_xy() for bar in last_group]
        assert_array_equal(first_origins, last_origins)

    def test_unique_bins(self, wide_df):
        """Without common bins, each column's bars should span its own range."""
        ax = histplot(wide_df, common_bins=False, bins=10, element="bars")

        columns = wide_df.columns
        bar_groups = np.split(np.array(ax.patches), len(columns))

        # The bar groups are matched against the columns in reverse order.
        for bars, col in zip(bar_groups, columns[::-1]):
            first_bar, last_bar = bars[0], bars[-1]
            left_edge = first_bar.get_x()
            right_edge = last_bar.get_x() + last_bar.get_width()
            assert_array_almost_equal(left_edge, wide_df[col].min())
            assert_array_almost_equal(right_edge, wide_df[col].max())

    def test_range_with_inf(self, rng):
        """An infinite value should not extend the leftmost bar edge."""
        sample = rng.normal(0, 1, 20)
        ax = histplot([-np.inf, *sample])
        left_edges = [patch.get_x() for patch in ax.patches]
        assert min(left_edges) == sample.min()

    def test_weights_with_missing(self, null_df):
        """The bar heights should sum to the weight of the complete rows."""
        ax = histplot(null_df, x="x", weights="s", bins=5)

        # Only rows where both the value and its weight are present count.
        complete_rows = null_df[["x", "s"]].dropna()
        expected_total = complete_rows["s"].sum()

        observed_total = sum(bar.get_height() for bar in ax.patches)
        assert observed_total == pytest.approx(expected_total)

    def test_weight_norm(self, rng):
        """Under `common_norm`, group densities should follow the weights.

        The same values are plotted twice, once with weight 1 and once with
        weight 2, using the weight as `hue`; one group's summed bar heights
        should be twice the other's.
        """
        vals = rng.normal(0, 1, 50)
        x = np.concatenate([vals, vals])
        w = np.repeat([1, 2], 50)
        ax = histplot(
            x=x, weights=w, hue=w, common_norm=True, stat="density", bins=5
        )

        # Recall that artists are added in reverse of hue order
        y1 = [bar.get_height() for bar in ax.patches[:5]]
        y2 = [bar.get_height() for bar in ax.patches[5:]]

        # The sums are accumulated floats, so compare with a tolerance
        # rather than requiring bit-for-bit equality.
        assert sum(y1) == pytest.approx(2 * sum(y2))

    def test_discrete(self, long_df):
        """Discrete bins should be unit-width bars centered on each integer."""
        ax = histplot(long_df, x="s", discrete=True)

        values = long_df["s"]
        lowest, highest = values.min(), values.max()
        assert len(ax.patches) == (highest - lowest + 1)

        for offset, bar in enumerate(ax.patches):
            assert bar.get_width() == 1
            assert bar.get_x() == (lowest + offset - .5)

    def test_discrete_categorical_default(self, long_df):
        """Every bar drawn for the "a" column should have unit width."""
        ax = histplot(long_df, x="a")
        for bar in ax.patches:
            assert bar.get_width() == 1

    def test_categorical_yaxis_inversion(self, long_df):
        """The first y limit should be larger than the second."""
        ax = histplot(long_df, y="a")
        first_limit, second_limit = ax.get_ylim()
        assert second_limit < first_limit

    def test_datetime_scale(self, long_df):
        """Filled and unfilled datetime histograms should share x limits."""
        f, axes = plt.subplots(2)
        for ax, fill in zip(axes, [True, False]):
            histplot(x=long_df["t"], fill=fill, ax=ax)
        filled_ax, outline_ax = axes
        assert filled_ax.get_xlim() == outline_ax.get_xlim()

    @pytest.mark.parametrize("stat", ["count", "density", "probability"])
    def test_kde(self, flat_series, stat):
        """The KDE curve should enclose the same area as the bars."""
        ax = histplot(flat_series, kde=True, stat=stat, kde_kws={"cut": 10})

        bars = ax.patches
        hist_area = np.multiply(
            [bar.get_width() for bar in bars],
            [bar.get_height() for bar in bars],
        ).sum()

        (curve,) = ax.lines
        kde_area = integrate(curve.get_ydata(), curve.get_xdata())

        assert kde_area == pytest.approx(hist_area)

    @pytest.mark.parametrize("multiple", ["layer", "dodge"])
    @pytest.mark.parametrize("stat", ["count", "density", "probability"])
    def test_kde_with_hue(self, long_df, stat, multiple):
        """Each hue group's KDE area should track the area of its bars."""
        n_bins = 10
        ax = histplot(
            long_df, x="x", hue="c", multiple=multiple,
            kde=True, stat=stat, element="bars",
            kde_kws={"cut": 10}, bins=n_bins,
        )

        first_group, last_group = ax.patches[:n_bins], ax.patches[-n_bins:]

        for group_index, bars in enumerate((first_group, last_group)):
            hist_area = np.multiply(
                [bar.get_width() for bar in bars],
                [bar.get_height() for bar in bars],
            ).sum()

            xs, ys = ax.lines[group_index].get_xydata().T
            kde_area = integrate(ys, xs)

            # For dodged bars the curve is expected to cover twice the area.
            if multiple == "dodge":
                assert kde_area == pytest.approx(hist_area * 2)
            elif multiple == "layer":
                assert kde_area == pytest.approx(hist_area)

    def test_kde_default_cut(self, flat_series):
        """By default the KDE curve should span exactly the data range."""
        ax = histplot(flat_series, kde=True)
        grid = ax.lines[0].get_xdata()
        lowest, highest = grid.min(), grid.max()
        assert lowest == flat_series.min()
        assert highest == flat_series.max()

    def test_kde_hue(self, long_df):
        """Each KDE line should share its color with the matching bars."""
        n_bins = 10
        ax = histplot(data=long_df, x="x", hue="a", kde=True, bins=n_bins)

        # Take every `n_bins`-th bar and pair it with the lines in order.
        sampled_bars = ax.patches[::n_bins]
        for bar, curve in zip(sampled_bars, ax.lines):
            bar_color, curve_color = bar.get_facecolor(), curve.get_color()
            assert_colors_equal(bar_color, curve_color, check_alpha=False)

    def test_kde_yaxis(self, flat_series):
        """Assigning the data to y should transpose the KDE curve."""
        f, ax = plt.subplots()
        for axis_name in ("x", "y"):
            histplot(kde=True, **{axis_name: flat_series})

        along_x, along_y = ax.lines
        assert_array_equal(along_x.get_xdata(), along_y.get_ydata())
        assert_array_equal(along_x.get_ydata(), along_y.get_xdata())

    def test_kde_line_kws(self, flat_series):
        """Entries of `line_kws` should be applied to the KDE line."""
        width = 5
        ax = histplot(flat_series, kde=True, line_kws={"lw": width})
        curve = ax.lines[0]
        assert curve.get_linewidth() == width

    def test_kde_singular_data(self):
        """Singular data should add no KDE line and emit no warning."""
        for singular in (np.ones(10), [5]):
            # Any warning raised inside this context becomes an error.
            with warnings.catch_warnings():
                warnings.simplefilter("error")
                ax = histplot(x=singular, kde=True)
            assert not ax.lines

    def test_element_default(self, long_df):
        """The default element should draw as many patches as "bars"."""
        # Check once without and once with a hue variable.
        for extra_kws in ({}, {"hue": "a"}):
            f, (ax_default, ax_bars) = plt.subplots(2)
            histplot(long_df, x="x", ax=ax_default, **extra_kws)
            histplot(long_df, x="x", ax=ax_bars, element="bars", **extra_kws)
            assert len(ax_default.patches) == len(ax_bars.patches)

    def test_bars_no_fill(self, flat_series):
        """Unfilled bars have a transparent face and `alpha` on the edge."""
        opacity = .5
        transparent = (0, 0, 0, 0)
        ax = histplot(flat_series, element="bars", fill=False, alpha=opacity)
        for bar in ax.patches:
            assert bar.get_facecolor() == transparent
            edge_alpha = bar.get_edgecolor()[-1]
            assert edge_alpha == opacity

    def test_step_fill(self, flat_series):
        """The filled step outline should follow the edges of the bars."""
        n_bins = 10
        f, (ax_bars, ax_step) = plt.subplots(2)
        for element, ax in (("bars", ax_bars), ("step", ax_step)):
            histplot(flat_series, element=element, fill=True, bins=n_bins, ax=ax)

        bars = ax_bars.patches
        heights = [bar.get_height() for bar in bars]
        widths = [bar.get_width() for bar in bars]
        edges = [bar.get_x() for bar in bars]

        polygon = ax_step.collections[0]
        # The path vertices are read in reverse order.
        xs, ys = polygon.get_paths()[0].vertices[::-1].T

        # Every second vertex, starting at index 1, is compared to a bar.
        bar_corners = slice(1, 2 * n_bins, 2)
        assert_array_equal(xs[bar_corners], edges)
        assert_array_equal(ys[bar_corners], heights)

        # The vertex after those should sit at the right side of the last bar.
        closing = n_bins * 2
        assert xs[closing] == edges[-1] + widths[-1]
        assert ys[closing] == heights[-1]

    def test_poly_fill(self, flat_series):
        """Filled polygon vertices sit at bar centers with the bar heights."""
        _, (bars_ax, poly_ax) = plt.subplots(2)

        n_bins = 10
        shared_kws = {"fill": True, "bins": n_bins}
        histplot(flat_series, element="bars", ax=bars_ax, **shared_kws)
        histplot(flat_series, element="poly", ax=poly_ax, **shared_kws)

        bars = bars_ax.patches
        heights = np.array([bar.get_height() for bar in bars])
        widths = np.array([bar.get_width() for bar in bars])
        lefts = np.array([bar.get_x() for bar in bars])

        polygon = poly_ax.collections[0]
        xs, ys = polygon.get_paths()[0].vertices[::-1].T

        # Only vertices 1..n_bins of the reversed path are compared with bars.
        per_bar = slice(1, n_bins + 1)
        assert_array_equal(xs[per_bar], lefts + widths / 2)
        assert_array_equal(ys[per_bar], heights)

    def test_poly_no_fill(self, flat_series):
        """Unfilled polygon line passes through each bar's top center."""
        _, (bars_ax, poly_ax) = plt.subplots(2)

        n_bins = 10
        shared_kws = {"fill": False, "bins": n_bins}
        histplot(flat_series, element="bars", ax=bars_ax, **shared_kws)
        histplot(flat_series, element="poly", ax=poly_ax, **shared_kws)

        bars = bars_ax.patches
        heights = np.array([bar.get_height() for bar in bars])
        widths = np.array([bar.get_width() for bar in bars])
        lefts = np.array([bar.get_x() for bar in bars])

        xs, ys = poly_ax.lines[0].get_xydata().T

        assert_array_equal(xs, lefts + widths / 2)
        assert_array_equal(ys, heights)

    def test_step_no_fill(self, flat_series):
        """Unfilled step line follows bar edges, closing at the last right edge."""
        _, (bars_ax, step_ax) = plt.subplots(2)

        for element, target_ax in (("bars", bars_ax), ("step", step_ax)):
            histplot(flat_series, element=element, fill=False, ax=target_ax)

        heights, widths, lefts = [], [], []
        for patch in bars_ax.patches:
            heights.append(patch.get_height())
            widths.append(patch.get_width())
            lefts.append(patch.get_x())

        xs, ys = step_ax.lines[0].get_xydata().T

        assert_array_equal(xs[:-1], lefts)
        assert_array_equal(ys[:-1], heights)

        # The final point repeats the last height at the last bar's right edge.
        assert xs[-1] == lefts[-1] + widths[-1]
        assert ys[-1] == ys[-2]

    def test_step_fill_xy(self, flat_series):
        """Filled step polygons for x and y have column-swapped vertices."""
        _, ax = plt.subplots()

        # Both calls draw onto the current axes created above.
        histplot(x=flat_series, element="step", fill=True)
        histplot(y=flat_series, element="step", fill=True)

        fills = ax.collections
        x_vertices = fills[0].get_paths()[0].vertices
        y_vertices = fills[1].get_paths()[0].vertices

        assert_array_equal(x_vertices, y_vertices[:, ::-1])

    def test_step_no_fill_xy(self, flat_series):
        """Unfilled step lines for x and y have swapped coordinates."""
        _, ax = plt.subplots()

        # Both calls draw onto the current axes created above.
        histplot(x=flat_series, element="step", fill=False)
        histplot(y=flat_series, element="step", fill=False)

        along_x, along_y = ax.lines

        assert_array_equal(along_x.get_xdata(), along_y.get_ydata())
        assert_array_equal(along_x.get_ydata(), along_y.get_xdata())

    def test_weighted_histogram(self):
        """With one observation per discrete bin, bar heights equal the weights."""
        ax = histplot(x=[0, 1, 2], weights=[1, 2, 3], discrete=True)

        heights = [patch.get_height() for patch in ax.patches]
        assert heights == [1, 2, 3]

    def test_weights_with_auto_bins(self, long_df):
        """Weights without explicit bins warn and yield ten bars."""
        with pytest.warns(UserWarning):
            ax = histplot(long_df, x="x", weights="f")

        n_bars = len(ax.patches)
        assert n_bars == 10

    def test_shrink(self, long_df):
        """``shrink`` scales each bar's width while keeping its center fixed."""
        _, (full_ax, shrunk_ax) = plt.subplots(2)

        binwidth = 2
        factor = .4

        histplot(long_df, x="x", binwidth=binwidth, ax=full_ax)
        histplot(long_df, x="x", binwidth=binwidth, shrink=factor, ax=shrunk_ax)

        for full, shrunk in zip(full_ax.patches, shrunk_ax.patches):

            full_width = full.get_width()
            shrunk_width = shrunk.get_width()
            assert shrunk_width == pytest.approx(factor * full_width)

            full_center = full.get_x() + full_width / 2
            shrunk_center = shrunk.get_x() + shrunk_width / 2
            assert shrunk_center == pytest.approx(full_center)

    def test_log_scale_explicit(self, rng):
        """With ``log_scale=True`` and unit binwidth, bar widths grow tenfold."""
        samples = rng.lognormal(0, 2, 1000)
        ax = histplot(samples, log_scale=True, binrange=(-3, 3), binwidth=1)

        widths = [patch.get_width() for patch in ax.patches]
        ratios = np.divide(widths[1:], widths[:-1])
        assert np.allclose(ratios, 10)

    def test_log_scale_implicit(self, rng):
        """A log-scaled x axis set beforehand also gives tenfold width steps."""
        samples = rng.lognormal(0, 2, 1000)

        _, ax = plt.subplots()
        ax.set_xscale("log")
        histplot(samples, binrange=(-3, 3), binwidth=1, ax=ax)

        widths = [patch.get_width() for patch in ax.patches]
        ratios = np.divide(widths[1:], widths[:-1])
        assert np.allclose(ratios, 10)

    def test_log_scale_dodge(self, rng):
        """Dodged bars on a log scale all share one width in log units."""
        samples = rng.lognormal(0, 2, 100)
        groups = np.repeat(["a", "b"], 50)
        ax = histplot(
            x=samples, hue=groups, bins=5, log_scale=True, multiple="dodge",
        )

        log_left = np.log([patch.get_x() for patch in ax.patches])
        log_right = np.log(
            [patch.get_x() + patch.get_width() for patch in ax.patches]
        )

        # Round before deduplicating so float noise does not create extra values.
        log_widths = np.round(log_right - log_left, 10)
        assert np.unique(log_widths).size == 1

    def test_log_scale_kde(self, rng):
        """On a log scale the KDE peak is within 10% of the tallest bar."""
        samples = rng.lognormal(0, 1, 1000)
        ax = histplot(x=samples, log_scale=True, kde=True, bins=20)

        tallest_bar = max(patch.get_height() for patch in ax.patches)
        kde_peak = max(ax.lines[0].get_ydata())
        assert tallest_bar == pytest.approx(kde_peak, rel=.1)

    @pytest.mark.parametrize(
        "fill", [True, False],
    )
    def test_auto_linewidth(self, flat_series, fill):
        """Compare the automatic bar linewidth across bins, sizes, and scales."""

        def first_bar_lw(ax):
            """Return the linewidth of the first bar drawn on ``ax``."""
            return ax.patches[0].get_linewidth()

        bar_kws = {"element": "bars", "fill": fill}

        # Fewer bins on the same data should give a thicker outline.
        _, (coarse_ax, fine_ax) = plt.subplots(2)
        histplot(flat_series, **bar_kws, bins=10, ax=coarse_ax)
        histplot(flat_series, **bar_kws, bins=100, ax=fine_ax)
        assert first_bar_lw(coarse_ax) > first_bar_lw(fine_ax)

        # A wider figure should give a thicker outline for the same bins.
        _, wide_ax = plt.subplots(figsize=(10, 5))
        _, narrow_ax = plt.subplots(figsize=(2, 5))
        histplot(flat_series, **bar_kws, bins=30, ax=wide_ax)
        histplot(flat_series, **bar_kws, bins=30, ax=narrow_ax)
        assert first_bar_lw(wide_ax) > first_bar_lw(narrow_ax)

        # Exponentiated data on a log scale should match the linear case.
        _, linear_ax = plt.subplots(figsize=(4, 5))
        _, log_ax = plt.subplots(figsize=(4, 5))
        histplot(flat_series, **bar_kws, bins=30, ax=linear_ax)
        histplot(10 ** flat_series, **bar_kws, bins=30, log_scale=True, ax=log_ax)
        assert first_bar_lw(linear_ax) == pytest.approx(first_bar_lw(log_ax))

        # Discrete numeric data and string data on y should match.
        _, numeric_ax = plt.subplots(figsize=(4, 5))
        _, string_ax = plt.subplots(figsize=(4, 5))
        histplot(y=[0, 1, 1], **bar_kws, discrete=True, ax=numeric_ax)
        histplot(y=["a", "b", "b"], **bar_kws, ax=string_ax)
        assert first_bar_lw(numeric_ax) == pytest.approx(first_bar_lw(string_ax))

    def test_bar_kwargs(self, flat_series):
        """``ec`` and ``lw`` keywords are applied to every bar."""
        linewidth = 2
        edgecolor = (1, .2, .9, .5)

        ax = histplot(flat_series, binwidth=1, ec=edgecolor, lw=linewidth)

        for patch in ax.patches:
            assert_colors_equal(patch.get_edgecolor(), edgecolor)
            assert patch.get_linewidth() == linewidth

    def test_step_fill_kwargs(self, flat_series):
        """``ec`` and ``lw`` keywords are applied to the filled step polygon."""
        linewidth = 2
        edgecolor = (1, .2, .9, .5)

        ax = histplot(flat_series, element="step", ec=edgecolor, lw=linewidth)

        polygon = ax.collections[0]
        assert_colors_equal(polygon.get_edgecolor(), edgecolor)
        assert polygon.get_linewidth() == linewidth

    def test_step_line_kwargs(self, flat_series):
        """``lw`` and ``ls`` keywords are applied to the unfilled step line."""
        linewidth = 2
        linestyle = "--"

        ax = histplot(
            flat_series, element="step", fill=False, lw=linewidth, ls=linestyle,
        )

        step_line = ax.lines[0]
        assert step_line.get_linewidth() == linewidth
        assert step_line.get_linestyle() == linestyle

    def test_label(self, flat_series):
        """A ``label`` keyword yields exactly one legend handle with that text."""
        text = "a label"
        ax = histplot(flat_series, label=text)

        handles, labels = ax.get_legend_handles_labels()
        assert len(handles) == 1
        assert labels == [text]

    def test_default_color_scout_cleanup(self, flat_series):
        """A plain ``histplot`` call leaves a single container on the axes."""
        ax = histplot(flat_series)
        n_containers = len(ax.containers)
        assert n_containers == 1


class TestHistPlotBivariate:
    """Checks on the mesh that ``histplot`` draws when given both x and y."""

    def test_mesh(self, long_df):
        """Mesh data, mask, and cell corners agree with ``Histogram``."""
        estimator = Histogram()
        heights, (x_edges, y_edges) = estimator(long_df["x"], long_df["y"])

        ax = histplot(long_df, x="x", y="y")
        mesh = ax.collections[0]
        values = mesh.get_array()

        assert_array_equal(values.data.flat, heights.T.flat)
        assert_array_equal(values.mask.flat, heights.T.flat == 0)

        corners = itertools.product(y_edges[:-1], x_edges[:-1])
        for idx, (y0, x0) in enumerate(corners):
            first_vertex = mesh.get_paths()[idx].vertices[0]
            assert first_vertex[0] == x0
            assert first_vertex[1] == y0

    def test_mesh_with_hue(self, long_df):
        """Each hue level's mesh matches a histogram on bins from all the data."""
        ax = histplot(long_df, x="x", y="y", hue="c")

        estimator = Histogram()
        estimator.define_bin_params(long_df["x"], long_df["y"])

        # The group key is used directly as the index into ``ax.collections``.
        for level, subset in long_df.groupby("c"):

            mesh = ax.collections[level]
            values = mesh.get_array()

            heights, (x_edges, y_edges) = estimator(subset["x"], subset["y"])

            assert_array_equal(values.data.flat, heights.T.flat)
            assert_array_equal(values.mask.flat, heights.T.flat == 0)

            corners = itertools.product(y_edges[:-1], x_edges[:-1])
            for idx, (y0, x0) in enumerate(corners):
                first_vertex = mesh.get_paths()[idx].vertices[0]
                assert first_vertex[0] == x0
                assert first_vertex[1] == y0

    def test_mesh_with_hue_unique_bins(self, long_df):
        """With ``common_bins=False`` each hue level is binned on its own data."""
        ax = histplot(long_df, x="x", y="y", hue="c", common_bins=False)

        # The group key is used directly as the index into ``ax.collections``.
        for level, subset in long_df.groupby("c"):

            estimator = Histogram()

            mesh = ax.collections[level]
            values = mesh.get_array()

            heights, (x_edges, y_edges) = estimator(subset["x"], subset["y"])

            assert_array_equal(values.data.flat, heights.T.flat)
            assert_array_equal(values.mask.flat, heights.T.flat == 0)

            corners = itertools.product(y_edges[:-1], x_edges[:-1])
            for idx, (y0, x0) in enumerate(corners):
                first_vertex = mesh.get_paths()[idx].vertices[0]
                assert first_vertex[0] == x0
                assert first_vertex[1] == y0

    def test_mesh_with_col_unique_bins(self, long_df):
        """With ``common_bins=False`` each ``displot`` facet is binned alone."""
        grid = displot(long_df, x="x", y="y", col="c", common_bins=False)

        # The group key is used directly as the index into the flat axes array.
        for level, subset in long_df.groupby("c"):

            estimator = Histogram()

            mesh = grid.axes.flat[level].collections[0]
            values = mesh.get_array()

            heights, (x_edges, y_edges) = estimator(subset["x"], subset["y"])

            assert_array_equal(values.data.flat, heights.T.flat)
            assert_array_equal(values.mask.flat, heights.T.flat == 0)

            corners = itertools.product(y_edges[:-1], x_edges[:-1])
            for idx, (y0, x0) in enumerate(corners):
                first_vertex = mesh.get_paths()[idx].vertices[0]
                assert first_vertex[0] == x0
                assert first_vertex[1] == y0

    def test_mesh_log_scale(self, rng):
        """With ``log_scale=True`` the mesh matches a histogram of log10 data."""
        x, y = rng.lognormal(0, 1, (2, 1000))
        estimator = Histogram()
        heights, (x_edges, y_edges) = estimator(np.log10(x), np.log10(y))

        ax = histplot(x=x, y=y, log_scale=True)
        mesh = ax.collections[0]
        values = mesh.get_array()

        assert_array_equal(values.data.flat, heights.T.flat)

        # Edges were computed in log10 units; cell corners are in data units.
        corners = itertools.product(y_edges[:-1], x_edges[:-1])
        for idx, (log_y0, log_x0) in enumerate(corners):
            first_vertex = mesh.get_paths()[idx].vertices[0]
            assert first_vertex[0] == pytest.approx(10 ** log_x0)
            assert first_vertex[1] == pytest.approx(10 ** log_y0)

    def test_mesh_thresh(self, long_df):
        """Cells with counts at or below ``thresh`` are masked."""
        estimator = Histogram()
        heights, (x_edges, y_edges) = estimator(long_df["x"], long_df["y"])

        cutoff = 5
        ax = histplot(long_df, x="x", y="y", thresh=cutoff)
        mesh = ax.collections[0]
        values = mesh.get_array()

        assert_array_equal(values.data.flat, heights.T.flat)
        assert_array_equal(values.mask.flat, (heights <= cutoff).T.flat)

    def test_mesh_sticky_edges(self, long_df):
        """Sticky edges span the data range only when ``thresh=None``."""
        ax = histplot(long_df, x="x", y="y", thresh=None)
        mesh = ax.collections[0]
        x_range = [long_df["x"].min(), long_df["x"].max()]
        assert mesh.sticky_edges.x == x_range
        y_range = [long_df["y"].min(), long_df["y"].max()]
        assert mesh.sticky_edges.y == y_range

        ax.clear()
        ax = histplot(long_df, x="x", y="y")
        mesh = ax.collections[0]
        assert not mesh.sticky_edges.x
        assert not mesh.sticky_edges.y

    def test_mesh_common_norm(self, long_df):
        """With ``common_norm=True`` densities scale by each group's share."""
        stat = "density"
        ax = histplot(
            long_df, x="x", y="y", hue="c", common_norm=True, stat=stat,
        )

        estimator = Histogram(stat="density")
        estimator.define_bin_params(long_df["x"], long_df["y"])

        # The group key is used directly as the index into ``ax.collections``.
        for level, subset in long_df.groupby("c"):

            mesh = ax.collections[level]
            values = mesh.get_array()

            density, (x_edges, y_edges) = estimator(subset["x"], subset["y"])

            share = len(subset) / len(long_df)
            assert_array_equal(values.data.flat, (density * share).T.flat)

    def test_mesh_unique_norm(self, long_df):
        """With ``common_norm=False`` each group's density is left unscaled."""
        stat = "density"
        ax = histplot(
            long_df, x="x", y="y", hue="c", common_norm=False, stat=stat,
        )

        estimator = Histogram()
        bin_kws = estimator.define_bin_params(long_df["x"], long_df["y"])

        # The group key is used directly as the index into ``ax.collections``.
        for level, subset in long_df.groupby("c"):

            group_estimator = Histogram(bins=bin_kws["bins"], stat=stat)

            mesh = ax.collections[level]
            values = mesh.get_array()

            density, (x_edges, y_edges) = group_estimator(
                subset["x"], subset["y"]
            )
            assert_array_equal(values.data.flat, density.T.flat)

    @pytest.mark.parametrize("stat", ["probability", "proportion", "percent"])
    def test_mesh_normalization(self, long_df, stat):
        """Normalized mesh values sum to 100 for percent and to 1 otherwise."""
        ax = histplot(
            long_df, x="x", y="y", stat=stat,
        )

        values = ax.collections[0].get_array()
        total = 100 if stat == "percent" else 1
        assert values.data.sum() == total

    def test_mesh_colors(self, long_df):
        """Mesh colormaps are built from ``color`` or from the hue palette."""
        base_color = "r"
        _, ax = plt.subplots()
        histplot(
            long_df, x="x", y="y", color=base_color,
        )
        mesh = ax.collections[0]
        assert_array_equal(
            mesh.get_cmap().colors,
            _DistributionPlotter()._cmap_from_color(base_color).colors,
        )

        _, ax = plt.subplots()
        histplot(
            long_df, x="x", y="y", hue="c",
        )
        palette = color_palette()
        for idx, mesh in enumerate(ax.collections):
            assert_array_equal(
                mesh.get_cmap().colors,
                _DistributionPlotter()._cmap_from_color(palette[idx]).colors,
            )

    def test_color_limits(self, long_df):
        """Color limits follow the counts, ``vmax``, or the ``pmax`` level."""
        _, (ax1, ax2, ax3) = plt.subplots(3)
        plot_kws = {"data": long_df, "x": "x", "y": "y"}
        estimator = Histogram()
        heights, _ = estimator(long_df["x"], long_df["y"])

        histplot(**plot_kws, ax=ax1)
        assert ax1.collections[0].get_clim() == (0, heights.max())

        vmax = 10
        histplot(**plot_kws, vmax=vmax, ax=ax2)
        heights, _ = estimator(long_df["x"], long_df["y"])
        assert ax2.collections[0].get_clim() == (0, vmax)

        pmax = .8
        pthresh = .1
        to_level = _DistributionPlotter()._quantile_to_level

        histplot(**plot_kws, pmax=pmax, pthresh=pthresh, ax=ax3)
        heights, _ = estimator(long_df["x"], long_df["y"])
        mesh = ax3.collections[0]
        assert mesh.get_clim() == (0, to_level(heights, pmax))
        assert_array_equal(
            mesh.get_array().mask.flat,
            (heights <= to_level(heights, pthresh)).T.flat,
        )

    def test_hue_color_limits(self, long_df):
        """Per-hue color limits and masks depend on ``common_norm``."""
        _, (ax1, ax2, ax3, ax4) = plt.subplots(4)
        plot_kws = dict(data=long_df, x="x", y="y", hue="c", bins=4)

        estimator = Histogram(bins=plot_kws["bins"])
        estimator.define_bin_params(long_df["x"], long_df["y"])
        full_heights, _ = estimator(long_df["x"], long_df["y"])

        group_heights = []
        for _, subset in long_df.groupby(plot_kws["hue"]):
            heights, _ = estimator(subset["x"], subset["y"])
            group_heights.append(heights)

        pmax = .8
        pthresh = .05
        to_level = _DistributionPlotter()._quantile_to_level

        histplot(**plot_kws, common_norm=True, ax=ax1)
        for idx, mesh in enumerate(ax1.collections):
            assert mesh.get_clim() == (0, full_heights.max())

        histplot(**plot_kws, common_norm=False, ax=ax2)
        for idx, mesh in enumerate(ax2.collections):
            assert mesh.get_clim() == (0, group_heights[idx].max())

        histplot(
            **plot_kws, common_norm=True, pmax=pmax, pthresh=pthresh, ax=ax3
        )
        for idx, mesh in enumerate(ax3.collections):
            assert mesh.get_clim() == (0, to_level(full_heights, pmax))
            assert_array_equal(
                mesh.get_array().mask.flat,
                (group_heights[idx] <= to_level(full_heights, pthresh)).T.flat,
            )

        histplot(
            **plot_kws, common_norm=False, pmax=pmax, pthresh=pthresh, ax=ax4
        )
        for idx, mesh in enumerate(ax4.collections):
            own_heights = group_heights[idx]
            assert mesh.get_clim() == (0, to_level(own_heights, pmax))
            assert_array_equal(
                mesh.get_array().mask.flat,
                (own_heights <= to_level(own_heights, pthresh)).T.flat,
            )

    def test_colorbar(self, long_df):
        """``cbar=True`` adds one axes, or reuses the supplied ``cbar_ax``."""
        _, ax = plt.subplots()
        histplot(long_df, x="x", y="y", cbar=True, ax=ax)
        assert len(ax.figure.axes) == 2

        _, (ax, cax) = plt.subplots(2)
        histplot(long_df, x="x", y="y", cbar=True, cbar_ax=cax, ax=ax)
        assert len(ax.figure.axes) == 2


class TestECDFPlotUnivariate(SharedAxesLevelTests):
    """Tests for ``ecdfplot`` with a single data variable."""

    func = staticmethod(ecdfplot)

    def get_last_color(self, ax):
        """Return the RGB color of the most recently added line on ``ax``."""
        newest_line = ax.lines[-1]
        return to_rgb(newest_line.get_color())

    @pytest.mark.parametrize("variable", ["x", "y"])
    def test_long_vectors(self, long_df, variable):
        """Name, Series, array, and list inputs all draw identical curves."""
        column = long_df[variable]
        specs = [
            variable, column, column.to_numpy(), column.to_list(),
        ]

        _, ax = plt.subplots()
        for spec in specs:
            ecdfplot(data=long_df, ax=ax, **{variable: spec})

        # Compare every pair of lines, first on x data and then on y data.
        for getter in ("get_xdata", "get_ydata"):
            curves = [getattr(line, getter)() for line in ax.lines]
            for first, second in itertools.product(curves, curves):
                assert_array_equal(first, second)

    def test_hue(self, long_df):
        """Lines, taken in reverse order, follow the default palette."""
        ax = ecdfplot(long_df, x="x", hue="a")

        for line, expected in zip(ax.lines[::-1], color_palette()):
            assert_colors_equal(line.get_color(), expected)

    def test_line_kwargs(self, long_df):
        """Color, linestyle, and linewidth keywords reach every line."""
        color = "r"
        linestyle = "--"
        linewidth = 3
        ax = ecdfplot(long_df, x="x", color=color, ls=linestyle, lw=linewidth)

        for line in ax.lines:
            assert_colors_equal(line.get_color(), color)
            assert line.get_linestyle() == linestyle
            assert line.get_linewidth() == linewidth

    @pytest.mark.parametrize("data_var", ["x", "y"])
    def test_drawstyle(self, flat_series, data_var):
        """The step drawstyle depends on which axis holds the data."""
        ax = ecdfplot(**{data_var: flat_series})
        expected = {"x": "steps-post", "y": "steps-pre"}[data_var]
        assert ax.lines[0].get_drawstyle() == expected

    @pytest.mark.parametrize(
        "data_var,stat_var", [["x", "y"], ["y", "x"]],
    )
    def test_proportion_limits(self, flat_series, data_var, stat_var):
        """The proportion runs from 0 to 1 with matching sticky edges."""
        ax = ecdfplot(**{data_var: flat_series})
        curve = ax.lines[0]

        stat_values = getattr(curve, f"get_{stat_var}data")()
        assert stat_values[0] == 0
        assert stat_values[-1] == 1

        sticky = getattr(curve.sticky_edges, stat_var)
        assert sticky[:] == [0, 1]

    @pytest.mark.parametrize(
        "data_var,stat_var", [["x", "y"], ["y", "x"]],
    )
    def test_proportion_limits_complementary(self, flat_series, data_var, stat_var):
        """The complementary proportion runs from 1 down to 0."""
        ax = ecdfplot(**{data_var: flat_series}, complementary=True)
        curve = ax.lines[0]

        stat_values = getattr(curve, f"get_{stat_var}data")()
        assert stat_values[0] == 1
        assert stat_values[-1] == 0

        sticky = getattr(curve.sticky_edges, stat_var)
        assert sticky[:] == [0, 1]

    @pytest.mark.parametrize(
        "data_var,stat_var", [["x", "y"], ["y", "x"]],
    )
    def test_proportion_count(self, flat_series, data_var, stat_var):
        """With ``stat="count"`` the curve runs from 0 to the sample size."""
        n_obs = len(flat_series)
        ax = ecdfplot(**{data_var: flat_series}, stat="count")
        curve = ax.lines[0]

        stat_values = getattr(curve, f"get_{stat_var}data")()
        assert stat_values[0] == 0
        assert stat_values[-1] == n_obs

        sticky = getattr(curve.sticky_edges, stat_var)
        assert sticky[:] == [0, n_obs]

    def test_weights(self):
        """Weights set the size of each step in the cumulative proportion."""
        ax = ecdfplot(x=[1, 2, 3], weights=[1, 1, 2])
        cumulative = ax.lines[0].get_ydata()
        assert_array_equal(cumulative, [0, .25, .5, 1])

    def test_bivariate_error(self, long_df):
        """Passing both x and y raises ``NotImplementedError``."""
        with pytest.raises(NotImplementedError, match="Bivariate ECDF plots"):
            ecdfplot(data=long_df, x="x", y="y")

    def test_log_scale(self, long_df):
        """Apart from the first point, ``log_scale`` leaves the curve unchanged."""
        linear_ax, log_ax = plt.figure().subplots(2)

        ecdfplot(data=long_df, x="z", ax=linear_ax)
        ecdfplot(data=long_df, x="z", log_scale=True, ax=log_ax)

        # Ignore first point, which is either -inf (in linear) or 0 (in log)
        linear_points = linear_ax.lines[0].get_xydata()[1:]
        log_points = log_ax.lines[0].get_xydata()[1:]

        assert_array_almost_equal(linear_points, log_points)


class TestDisPlot:
    """Tests for the figure-level ``displot`` interface."""

    # TODO probably good to move these utility attributes/methods somewhere else
    @pytest.mark.parametrize(
        "kwargs", [
            dict(),
            dict(x="x"),
            dict(x="t"),
            dict(x="a"),
            dict(x="z", log_scale=True),
            dict(x="x", binwidth=4),
            dict(x="x", weights="f", bins=5),
            dict(x="x", color="green", linewidth=2, binwidth=4),
            dict(x="x", hue="a", fill=False),
            dict(x="y", hue="a", fill=False),
            dict(x="x", hue="a", multiple="stack"),
            dict(x="x", hue="a", element="step"),
            dict(x="x", hue="a", palette="muted"),
            dict(x="x", hue="a", kde=True),
            dict(x="x", hue="a", stat="density", common_norm=False),
            dict(x="x", y="y"),
        ],
    )
    def test_versus_single_histplot(self, long_df, kwargs):
        """``displot`` matches ``histplot``, alone and with a one-level facet."""
        ax = histplot(long_df, **kwargs)
        grid = displot(long_df, **kwargs)
        assert_plots_equal(ax, grid.ax)

        axes_legend = ax.legend_
        if axes_legend is not None:
            assert_legends_equal(axes_legend, grid._legend)

        if not kwargs:
            return

        # Facet on a column holding one constant value.
        long_df["_"] = "_"
        faceted = displot(long_df, col="_", **kwargs)
        assert_plots_equal(ax, faceted.ax)

    @pytest.mark.parametrize(
        "kwargs", [
            dict(),
            dict(x="x"),
            dict(x="t"),
            dict(x="z", log_scale=True),
            dict(x="x", bw_adjust=.5),
            dict(x="x", weights="f"),
            dict(x="x", color="green", linewidth=2),
            dict(x="x", hue="a", multiple="stack"),
            dict(x="x", hue="a", fill=True),
            dict(x="y", hue="a", fill=False),
            dict(x="x", hue="a", palette="muted"),
            dict(x="x", y="y"),
        ],
    )
    def test_versus_single_kdeplot(self, long_df, kwargs):
        """``displot(kind="kde")`` matches ``kdeplot``, alone and faceted."""
        ax = kdeplot(data=long_df, **kwargs)
        grid = displot(long_df, kind="kde", **kwargs)
        assert_plots_equal(ax, grid.ax)

        axes_legend = ax.legend_
        if axes_legend is not None:
            assert_legends_equal(axes_legend, grid._legend)

        if not kwargs:
            return

        # Facet on a column holding one constant value.
        long_df["_"] = "_"
        faceted = displot(long_df, kind="kde", col="_", **kwargs)
        assert_plots_equal(ax, faceted.ax)

    @pytest.mark.parametrize(
        "kwargs", [
            dict(),
            dict(x="x"),
            dict(x="t"),
            dict(x="z", log_scale=True),
            dict(x="x", weights="f"),
            dict(y="x"),
            dict(x="x", color="green", linewidth=2),
            dict(x="x", hue="a", complementary=True),
            dict(x="x", hue="a", stat="count"),
            dict(x="x", hue="a", palette="muted"),
        ],
    )
    def test_versus_single_ecdfplot(self, long_df, kwargs):
        """``displot(kind="ecdf")`` matches ``ecdfplot``, alone and faceted."""
        ax = ecdfplot(data=long_df, **kwargs)
        grid = displot(long_df, kind="ecdf", **kwargs)
        assert_plots_equal(ax, grid.ax)

        axes_legend = ax.legend_
        if axes_legend is not None:
            assert_legends_equal(axes_legend, grid._legend)

        if not kwargs:
            return

        # Facet on a column holding one constant value.
        long_df["_"] = "_"
        faceted = displot(long_df, kind="ecdf", col="_", **kwargs)
        assert_plots_equal(ax, faceted.ax)

    @pytest.mark.parametrize(
        "kwargs", [
            dict(x="x"),
            dict(x="x", y="y"),
            dict(x="x", hue="a"),
        ]
    )
    def test_with_rug(self, long_df, kwargs):
        """``rug=True`` matches ``histplot`` plus ``rugplot`` on one axes."""
        ax = plt.figure().subplots()
        for plotter in (histplot, rugplot):
            plotter(data=long_df, **kwargs, ax=ax)

        grid = displot(long_df, rug=True, **kwargs)

        assert_plots_equal(ax, grid.ax, labels=False)

        # Facet on a column holding one constant value.
        long_df["_"] = "_"
        faceted = displot(long_df, col="_", rug=True, **kwargs)

        assert_plots_equal(ax, faceted.ax, labels=False)

    @pytest.mark.parametrize(
        "facet_var", ["col", "row"],
    )
    def test_facets(self, long_df, facet_var):
        """Each facet's curve matches the corresponding hue-level curve."""
        facet_kws = {facet_var: "a"}
        ax = kdeplot(data=long_df, x="x", hue="a")
        grid = displot(long_df, x="x", kind="kde", **facet_kws)

        legend_texts = ax.legend_.get_texts()

        # Hue lines are walked in reverse to pair with facets and legend texts.
        for idx, hue_line in enumerate(ax.lines[::-1]):
            facet_ax = grid.axes.flat[idx]
            facet_line = facet_ax.lines[0]
            assert_array_equal(hue_line.get_xydata(), facet_line.get_xydata())

            level_name = legend_texts[idx].get_text()
            assert level_name in facet_ax.get_title()

    @pytest.mark.parametrize("multiple", ["dodge", "stack", "fill"])
    def test_facet_multiple(self, long_df, multiple):
        """A facet matches ``histplot`` on the matching subset of the data."""
        bin_edges = np.linspace(0, 20, 5)
        hue_levels = ["a", "b", "c"]

        ax = histplot(
            data=long_df[long_df["c"] == 0],
            x="x", hue="a", hue_order=hue_levels,
            multiple=multiple, bins=bin_edges,
        )

        grid = displot(
            data=long_df, x="x", hue="a", col="c", hue_order=hue_levels,
            multiple=multiple, bins=bin_edges,
        )

        assert_plots_equal(ax, grid.axes_dict[0])

    def test_ax_warning(self, long_df):
        """Passing ``ax`` to ``displot`` emits a ``UserWarning``."""
        ax = plt.figure().subplots()
        with pytest.warns(UserWarning, match="`displot` is a figure-level"):
            displot(long_df, x="x", ax=ax)

    @pytest.mark.parametrize("key", ["col", "row"])
    def test_array_faceting(self, long_df, key):
        """An array facet variable gives one titled axes per level."""
        facet_values = long_df["a"].to_numpy()
        levels = categorical_order(facet_values)
        grid = displot(long_df, x="x", **{key: facet_values})

        assert len(grid.axes.flat) == len(levels)
        for facet_ax, level in zip(grid.axes.flat, levels):
            assert level in facet_ax.get_title()

    def test_legend(self, long_df):
        """A hue variable gives the grid a legend."""
        grid = displot(long_df, x="x", hue="a")
        assert grid._legend is not None

    def test_empty(self):
        """Empty x and y still return a ``FacetGrid``."""
        grid = displot(x=[], y=[])
        assert isinstance(grid, FacetGrid)

    def test_bivariate_ecdf_error(self, long_df):
        """``kind="ecdf"`` with both x and y raises ``NotImplementedError``."""
        with pytest.raises(NotImplementedError):
            displot(long_df, x="x", y="y", kind="ecdf")

    def test_bivariate_kde_norm(self, rng):
        """``common_norm`` controls whether facets draw equal contour counts."""
        x, y = rng.normal(0, 1, (2, 100))
        facet = [0] * 80 + [1] * 20

        def count_contours(ax):
            """Count the non-empty contour levels drawn on ``ax``."""
            if _version_predates(mpl, "3.8.0rc1"):
                return sum(bool(get_contour_coords(c)) for c in ax.collections)
            contour_paths = ax.collections[0].get_paths()
            return sum(bool(path.vertices.size) for path in contour_paths)

        grid = displot(x=x, y=y, col=facet, kind="kde", levels=10)
        n_large = count_contours(grid.axes.flat[0])
        n_small = count_contours(grid.axes.flat[1])
        assert n_large > n_small

        grid = displot(
            x=x, y=y, col=facet, kind="kde", levels=10, common_norm=False
        )
        n_large = count_contours(grid.axes.flat[0])
        n_small = count_contours(grid.axes.flat[1])
        assert n_large == n_small

    def test_bivariate_hist_norm(self, rng):
        """``common_norm`` controls whether facets share color limits."""
        x, y = rng.normal(0, 1, (2, 100))
        facet = [0] * 80 + [1] * 20

        grid = displot(x=x, y=y, col=facet, kind="hist")
        clim_large = grid.axes.flat[0].collections[0].get_clim()
        clim_small = grid.axes.flat[1].collections[0].get_clim()
        assert clim_large == clim_small

        grid = displot(x=x, y=y, col=facet, kind="hist", common_norm=False)
        clim_large = grid.axes.flat[0].collections[0].get_clim()
        clim_small = grid.axes.flat[1].collections[0].get_clim()
        assert clim_large[1] > clim_small[1]

    def test_facetgrid_data(self, long_df):
        """``g.data`` holds the input columns plus hue and facet vectors."""
        grid = displot(
            data=long_df.to_dict(orient="list"),
            x="z",
            hue=long_df["a"].rename("hue_var"),
            col=long_df["c"].to_numpy(),
        )

        # The unnamed facet array is expected under the "_col_" column.
        original_names = long_df.columns.to_list()
        expected_cols = set(original_names + ["hue_var", "_col_"])
        assert set(grid.data.columns) == expected_cols

        assert_array_equal(grid.data["hue_var"], long_df["a"])
        assert_array_equal(grid.data["_col_"], long_df["c"])


def integrate(y, x):
    """Simple numerical integration for testing KDE code."""
    heights = np.asarray(y)
    widths = np.diff(np.asarray(x))
    # Each interval contributes width * (left + right); halve the total once.
    doubled_area = widths * heights[:-1] + widths * heights[1:]
    return doubled_area.sum() / 2


================================================
FILE: tests/test_docstrings.py
================================================
from seaborn._docstrings import DocstringComponents


# Raw component text; note the leading newline and the trailing indentation.
EXAMPLE_DICT = {
    "param_a": "\na : str\n    The first parameter.\n    ",
}


class ExampleClass:
    # Fixture for TestDocstringComponents.test_from_method: the method's
    # docstring is the input handed to DocstringComponents.from_function_params,
    # so its text (including the indentation of the parameter description)
    # is test data and should not be reformatted.
    def example_method(self):
        """An example method.

        Parameters
        ----------
        a : str
           A method parameter.

        """


# Fixture for TestDocstringComponents.test_from_function: the docstring below
# is the input handed to DocstringComponents.from_function_params, so its text
# is test data and should not be reformatted.
def example_func():
    """An example function.

    Parameters
    ----------
    a : str
        A function parameter.

    """


class TestDocstringComponents:
    """Checks the different ways of building a ``DocstringComponents``."""

    def test_from_dict(self):
        """Dict entries become attributes with surrounding whitespace removed."""
        components = DocstringComponents(EXAMPLE_DICT)
        expected = "a : str\n    The first parameter."
        assert components.param_a == expected

    def test_from_nested_components(self):
        """Nested components are reachable as attributes of the outer object."""
        inner = DocstringComponents(EXAMPLE_DICT)
        outer = DocstringComponents.from_nested_components(inner=inner)
        expected = "a : str\n    The first parameter."
        assert outer.inner.param_a == expected

    def test_from_function(self):
        """Parameter text is extracted from a function's docstring."""
        components = DocstringComponents.from_function_params(example_func)
        expected = "a : str\n    A function parameter."
        assert components.a == expected

    def test_from_method(self):
        """Parameter text is extracted from a method's docstring."""
        method = ExampleClass.example_method
        components = DocstringComponents.from_function_params(method)
        expected = "a : str\n    A method parameter."
        assert components.a == expected


================================================
FILE: tests/test_matrix.py
================================================
import tempfile
import copy

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import pandas as pd

try:
    from scipy.spatial import distance
    from scipy.cluster import hierarchy
    _no_scipy = False
except ImportError:
    _no_scipy = True

try:
    import fastcluster
    assert fastcluster
    _no_fastcluster = False
except ImportError:
    _no_fastcluster = True

import numpy.testing as npt
import pandas.testing as pdt
import pytest

from seaborn import matrix as mat
from seaborn import color_palette
from seaborn._compat import get_colormap
from seaborn._testing import assert_colors_equal


class TestHeatmap:
    # Dedicated RNG seeded from the class topic so the fixture values are
    # reproducible and separate from the global numpy random state.
    rs = np.random.RandomState(sum(map(ord, "heatmap")))

    # 4 x 8 normal draws, as a raw array and as a frame whose index is a
    # named Series (``test_df_input`` expects that name as the y label).
    x_norm = rs.randn(4, 8)
    letters = pd.Series(["A", "B", "C", "D"], name="letters")
    df_norm = pd.DataFrame(x_norm, index=letters)

    # 20 x 13 uniform draws, used by the color-limit tests.
    x_unif = rs.rand(20, 13)
    df_unif = pd.DataFrame(x_unif)

    # Keyword arguments handed to ``mat._HeatMapper`` by most tests; tests
    # copy this dict before overriding individual entries.
    default_kws = dict(vmin=None, vmax=None, cmap=None, center=None,
                       robust=False, annot=False, fmt=".2f", annot_kws=None,
                       cbar=True, cbar_kws=None, mask=None)

    def test_ndarray_input(self):
        """A bare ndarray is wrapped in a frame with positional tick labels."""
        plotter = mat._HeatMapper(self.x_norm, **self.default_kws)

        pdt.assert_frame_equal(plotter.data, pd.DataFrame(self.x_norm))
        npt.assert_array_equal(plotter.plot_data, self.x_norm)

        # Tick labels fall back to the positional column / row numbers.
        npt.assert_array_equal(plotter.xticklabels, np.arange(8))
        npt.assert_array_equal(plotter.yticklabels, np.arange(4))

        # An array carries no axis names, so both labels are empty.
        assert plotter.xlabel == ""
        assert plotter.ylabel == ""

    def test_array_like_input(self):
        """Objects that only implement ``__array__`` are treated like arrays."""

        class Wrapper:
            """Minimal array-like exposing its payload through ``__array__``."""

            def __init__(self, payload):
                self.data = payload

            def __array__(self, **kwargs):
                return np.asarray(self.data, **kwargs)

        wrapped = Wrapper(self.x_norm)
        plotter = mat._HeatMapper(wrapped, **self.default_kws)

        pdt.assert_frame_equal(plotter.data, pd.DataFrame(self.x_norm))
        npt.assert_array_equal(plotter.plot_data, self.x_norm)

        npt.assert_array_equal(plotter.xticklabels, np.arange(8))
        npt.assert_array_equal(plotter.yticklabels, np.arange(4))

        assert plotter.xlabel == ""
        assert plotter.ylabel == ""

    def test_df_input(self):
        """A DataFrame keeps its labels and its index name becomes the y label."""
        plotter = mat._HeatMapper(self.df_norm, **self.default_kws)

        pdt.assert_frame_equal(plotter.data, self.df_norm)
        npt.assert_array_equal(plotter.plot_data, self.x_norm)

        expected_rows = self.letters.values
        expected_cols = np.arange(8)
        npt.assert_array_equal(plotter.xticklabels, expected_cols)
        npt.assert_array_equal(plotter.yticklabels, expected_rows)

        assert plotter.xlabel == ""
        assert plotter.ylabel == "letters"

    def test_df_multindex_input(self):
        """MultiIndex levels are joined with "-" for tick and axis labels."""
        pairs = [("A", 1), ("B", 2), ("C", 3), ("D", 4)]
        multi = pd.MultiIndex.from_tuples(pairs, names=["letter", "number"])
        multi.name = "letter-number"

        frame = self.df_norm.copy()
        frame.index = multi

        joined = ["A-1", "B-2", "C-3", "D-4"]

        # MultiIndex on the rows -> y axis.
        by_rows = mat._HeatMapper(frame, **self.default_kws)
        npt.assert_array_equal(by_rows.yticklabels, joined)
        assert by_rows.ylabel == "letter-number"

        # Same index on the columns -> x axis.
        by_cols = mat._HeatMapper(frame.T, **self.default_kws)
        npt.assert_array_equal(by_cols.xticklabels, joined)
        assert by_cols.xlabel == "letter-number"

    @pytest.mark.parametrize("dtype", [float, np.int64, object])
    def test_mask_input(self, dtype):
        """Masked cells come back masked in ``plot_data`` for each dtype."""
        positive = self.x_norm > 0
        typed = self.x_norm.astype(dtype)

        kws = dict(self.default_kws, mask=positive)
        plotter = mat._HeatMapper(typed, **kws)

        expected = np.ma.masked_where(positive, typed)
        npt.assert_array_equal(plotter.plot_data, expected)

    def test_mask_limits(self):
        """Color limits are computed from the unmasked cells only."""
        # Hide the positive cells first, then the negative ones.
        for hidden in (self.x_norm > 0, self.x_norm < 0):
            kws = dict(self.default_kws, mask=hidden)
            plotter = mat._HeatMapper(self.x_norm, **kws)

            visible = np.ma.array(self.x_norm, mask=hidden)
            assert plotter.vmin == visible.min()
            assert plotter.vmax == visible.max()

    def test_default_vlims(self):
        """Without overrides the color limits span the full data range."""
        plotter = mat._HeatMapper(self.df_unif, **self.default_kws)
        lowest, highest = self.x_unif.min(), self.x_unif.max()
        assert plotter.vmin == lowest
        assert plotter.vmax == highest

    def test_robust_vlims(self):
        """``robust=True`` uses the 2nd and 98th percentiles as limits."""
        kws = dict(self.default_kws, robust=True)
        plotter = mat._HeatMapper(self.df_unif, **kws)

        expected_low = np.percentile(self.x_unif, 2)
        expected_high = np.percentile(self.x_unif, 98)
        assert plotter.vmin == expected_low
        assert plotter.vmax == expected_high

    def test_custom_sequential_vlims(self):
        """Explicit ``vmin``/``vmax`` are used as given."""
        kws = dict(self.default_kws, vmin=0, vmax=1)
        plotter = mat._HeatMapper(self.df_unif, **kws)

        assert plotter.vmin == 0
        assert plotter.vmax == 1

    def test_custom_diverging_vlims(self):
        """Explicit limits are kept even when a ``center`` is requested."""
        kws = dict(self.default_kws, vmin=-4, vmax=5, center=0)
        plotter = mat._HeatMapper(self.df_norm, **kws)

        assert plotter.vmin == -4
        assert plotter.vmax == 5

    def test_array_with_nans(self):
        """An all-NaN column does not influence the color limits."""
        clean = self.rs.rand(10, 10)
        nan_column = np.full(10, np.nan)
        padded = np.column_stack([clean, nan_column])

        from_clean = mat._HeatMapper(clean, **self.default_kws)
        from_padded = mat._HeatMapper(padded, **self.default_kws)

        assert from_clean.vmin == from_padded.vmin
        assert from_clean.vmax == from_padded.vmax

    def test_mask(self):
        """The mask of ``plot_data`` lines up with the NaN cells of the data."""
        columns = {'a': [1, 1, 1],
                   'b': [2, np.nan, 2],
                   'c': [3, 3, np.nan]}
        frame = pd.DataFrame(data=columns)

        kws = dict(self.default_kws, mask=np.isnan(frame.values))
        plotter = mat._HeatMapper(frame, **kws)

        masked = plotter.plot_data
        npt.assert_array_equal(np.isnan(masked.data), masked.mask)

    def test_custom_cmap(self):
        """A colormap given by name resolves to the matplotlib colormap."""
        kws = dict(self.default_kws, cmap="BuGn")
        plotter = mat._HeatMapper(self.df_unif, **kws)
        assert plotter.cmap == mpl.cm.BuGn

    def test_centered_vlims(self):
        """Setting only ``center`` leaves the limits at the data extremes."""
        kws = dict(self.default_kws, center=.5)
        plotter = mat._HeatMapper(self.df_unif, **kws)

        values = self.df_unif.values
        assert plotter.vmin == values.min()
        assert plotter.vmax == values.max()

    def test_default_colors(self):
        """With default limits the data range spans the whole colormap."""
        binary = mpl.cm.binary
        row = np.linspace(.2, 1, 9)

        ax = mat.heatmap([row], cmap=binary)
        drawn = ax.collections[0].get_facecolors()

        # The 9 evenly spaced values should land evenly across [0, 1].
        expected = binary(np.linspace(0, 1, 9))
        npt.assert_array_almost_equal(drawn, expected, 2)

    def test_custom_vlim_colors(self):
        """With ``vmin=0`` the values map to the colormap at their own value."""
        binary = mpl.cm.binary
        row = np.linspace(.2, 1, 9)

        ax = mat.heatmap([row], vmin=0, cmap=binary)
        drawn = ax.collections[0].get_facecolors()

        npt.assert_array_almost_equal(drawn, binary(row), 2)

    def test_custom_center_colors(self):
        """With ``center=.5`` the values map to the colormap at their own value."""
        binary = mpl.cm.binary
        row = np.linspace(.2, 1, 9)

        ax = mat.heatmap([row], center=.5, cmap=binary)
        drawn = ax.collections[0].get_facecolors()

        npt.assert_array_almost_equal(drawn, binary(row), 2)

    def test_cmap_with_properties(self):
        """Custom bad/under/over colors survive, with and without centering.

        For each special color, a copy of the "BrBG" colormap is customized
        and handed to ``_HeatMapper``; the plotter's colormap must then map
        the matching probe value to the same color as the customized one.
        The "over" color is probed with ``+inf`` in both the centered and
        un-centered case (probing ``-inf`` would exercise the "under" color
        instead and leave the custom "over" color untested).
        """
        probes = [
            ("set_bad", np.ma.masked_invalid([np.nan])),
            ("set_under", -np.inf),
            ("set_over", np.inf),
        ]

        for setter_name, probe in probes:
            # Work on a copy so the registered colormap is never modified.
            cmap = copy.copy(get_colormap("BrBG"))
            getattr(cmap, setter_name)("red")

            # ``center=None`` is the default, i.e. the un-centered case.
            for center in (None, .5):
                kws = dict(self.default_kws, cmap=cmap, center=center)
                hm = mat._HeatMapper(self.df_unif, **kws)
                npt.assert_array_equal(cmap(probe), hm.cmap(probe))

    def test_explicit_none_norm(self):
        """Passing ``norm=None`` gives the same colors as omitting ``norm``."""
        row = [np.linspace(.2, 1, 9)]
        binary = mpl.cm.binary
        _, (ax_default, ax_explicit) = plt.subplots(2)

        mat.heatmap(row, vmin=0, cmap=binary, ax=ax_default)
        default_colors = ax_default.collections[0].get_facecolors()

        mat.heatmap(row, vmin=0, norm=None, cmap=binary, ax=ax_explicit)
        explicit_colors = ax_explicit.collections[0].get_facecolors()

        npt.assert_array_almost_equal(default_colors, explicit_colors, 2)

    def test_ticklabels_off(self):
        """``False`` for either ticklabel option yields an empty label list."""
        kws = dict(self.default_kws, xticklabels=False, yticklabels=False)
        plotter = mat._HeatMapper(self.df_norm, **kws)
        assert plotter.xticklabels == []
        assert plotter.yticklabels == []

    def test_custom_ticklabels(self):
        """Explicit label lists are used as given."""
        n_rows, n_cols = self.df_norm.shape
        col_labels = list('iheartheatmaps'[:n_cols])
        row_labels = list('heatmapsarecool'[:n_rows])

        kws = dict(self.default_kws,
                   xticklabels=col_labels, yticklabels=row_labels)
        plotter = mat._HeatMapper(self.df_norm, **kws)

        assert plotter.xticklabels == col_labels
        assert plotter.yticklabels == row_labels

    def test_custom_ticklabel_interval(self):
        """Integer ticklabel options label every n-th column / row."""
        col_step, row_step = 2, 3
        kws = dict(self.default_kws,
                   xticklabels=col_step, yticklabels=row_step)
        plotter = mat._HeatMapper(self.df_norm, **kws)

        n_rows, n_cols = self.df_norm.shape

        # Ticks are expected at the cell centers, hence the half-cell offset.
        npt.assert_array_equal(plotter.xticks,
                               np.arange(0, n_cols, col_step) + .5)
        npt.assert_array_equal(plotter.yticks,
                               np.arange(0, n_rows, row_step) + .5)

        npt.assert_array_equal(plotter.xticklabels,
                               self.df_norm.columns[0:n_cols:col_step])
        npt.assert_array_equal(plotter.yticklabels,
                               self.df_norm.index[0:n_rows:row_step])

    def test_heatmap_annotation(self):
        """Every cell is annotated with its formatted value and given font size."""
        ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
                         annot_kws={"fontsize": 14})

        # ``zip`` stops at the shorter input, so check the count explicitly;
        # otherwise the loop would pass even if no annotation were drawn.
        assert len(ax.texts) == self.x_norm.size

        for val, text in zip(self.x_norm.flat, ax.texts):
            assert text.get_text() == f"{val:.1f}"
            assert text.get_fontsize() == 14

    def test_heatmap_annotation_overwrite_kws(self):
        """``annot_kws`` entries for color and alignment reach every annotation."""
        annot_kws = dict(color="0.3", va="bottom", ha="left")
        ax = mat.heatmap(self.df_norm, annot=True, fmt=".1f",
                         annot_kws=annot_kws)

        # Guard against a vacuous pass: the loop below checks nothing if no
        # annotations were drawn.
        assert len(ax.texts) == self.df_norm.size

        for text in ax.texts:
            assert text.get_color() == "0.3"
            assert text.get_ha() == "left"
            assert text.get_va() == "bottom"

    def test_heatmap_annotation_with_mask(self):
        """Only the unmasked cells receive an annotation."""
        frame = pd.DataFrame(data={'a': [1, 1, 1],
                                   'b': [2, np.nan, 2],
                                   'c': [3, 3, np.nan]})
        hidden = np.isnan(frame.values)
        masked = np.ma.masked_where(hidden, frame)

        ax = mat.heatmap(frame, annot=True, fmt='.1f', mask=hidden)

        assert len(masked.compressed()) == len(ax.texts)
        for value, text in zip(masked.compressed(), ax.texts):
            assert f"{value:.1f}" == text.get_text()

    def test_heatmap_annotation_mesh_colors(self):
        """With annotations on, the mesh still has one face color per cell."""
        ax = mat.heatmap(self.df_norm, annot=True)

        face_colors = ax.collections[0].get_facecolors()
        assert len(face_colors) == self.df_norm.values.size

        plt.close("all")

    def test_heatmap_annotation_other_data(self):
        """A frame passed as ``annot`` supplies the annotation text."""
        annot_data = self.df_norm + 10

        ax = mat.heatmap(self.df_norm, annot=annot_data, fmt=".1f",
                         annot_kws={"fontsize": 14})

        # ``zip`` stops at the shorter input, so check the count explicitly;
        # otherwise the loop would pass even if no annotation were drawn.
        assert len(ax.texts) == annot_data.values.size

        for val, text in zip(annot_data.values.flat, ax.texts):
            assert text.get_text() == f"{val:.1f}"
            assert text.get_fontsize() == 14

    def test_heatmap_annotation_different_shapes(self):
        """Annotation data with a different shape than the data is rejected."""
        one_row_short = self.df_norm.iloc[:-1]
        with pytest.raises(ValueError):
            mat.heatmap(self.df_norm, annot=one_row_short)

    def test_heatmap_annotation_with_limited_ticklabels(self):
        """Turning tick labels off does not disturb the cell annotations."""
        ax = mat.heatmap(self.df_norm, fmt=".2f", annot=True,
                         xticklabels=False, yticklabels=False)

        # ``zip`` stops at the shorter input, so check the count explicitly;
        # otherwise the loop would pass even if no annotation were drawn.
        assert len(ax.texts) == self.x_norm.size

        for val, text in zip(self.x_norm.flat, ax.texts):
            assert text.get_text() == f"{val:.2f}"

    def test_heatmap_cbar(self):
        """The colorbar adds an axes unless disabled or given its own axes."""
        # Default: a second axes is created for the colorbar.
        fig = plt.figure()
        mat.heatmap(self.df_norm)
        assert len(fig.axes) == 2
        plt.close(fig)

        # Colorbar disabled: only the heatmap axes exists.
        fig = plt.figure()
        mat.heatmap(self.df_norm, cbar=False)
        assert len(fig.axes) == 1
        plt.close(fig)

        # Colorbar drawn into a provided axes: no additional axes appears.
        fig, (main_ax, cbar_ax) = plt.subplots(2)
        mat.heatmap(self.df_norm, ax=main_ax, cbar_ax=cbar_ax)
        assert len(fig.axes) == 2
        plt.close(fig)

    def test_heatmap_axes(self):
        """Tick labels, axis labels and limits reflect the input frame."""
        ax = mat.heatmap(self.df_norm)

        col_labels = [int(label.get_text()) for label in ax.get_xticklabels()]
        row_labels = [label.get_text() for label in ax.get_yticklabels()]
        assert col_labels == list(self.df_norm.columns)
        assert row_labels == list(self.df_norm.index)

        assert ax.get_xlabel() == ""
        assert ax.get_ylabel() == "letters"

        # The y limits are reversed, so the first row is drawn at the top.
        assert ax.get_xlim() == (0, 8)
        assert ax.get_ylim() == (4, 0)

    def test_heatmap_ticklabel_rotation(self):
        """Tick labels are rotated differently once the labels get long."""

        def draw(frame):
            """Draw ``frame`` with every tick labelled on a small figure."""
            fig, axes = plt.subplots(figsize=(2, 2))
            mat.heatmap(frame, xticklabels=1, yticklabels=1, ax=axes)
            return fig, axes

        # Short labels: x labels stay horizontal, y labels are at 90 degrees.
        fig, ax = draw(self.df_norm)
        for label in ax.get_xticklabels():
            assert label.get_rotation() == 0
        for label in ax.get_yticklabels():
            assert label.get_rotation() == 90
        plt.close(fig)

        # Ten-fold repeated labels: the expected rotations swap.
        long_labels = self.df_norm.copy()
        long_labels.columns = [str(c) * 10 for c in long_labels.columns]
        long_labels.index = [i * 10 for i in long_labels.index]

        fig, ax = draw(long_labels)
        for label in ax.get_xticklabels():
            assert label.get_rotation() == 90
        for label in ax.get_yticklabels():
            assert label.get_rotation() == 0
        plt.close(fig)

    def test_heatmap_inner_lines(self):
        """``linewidths`` and ``linecolor`` are applied to the mesh edges."""
        blue = (0, 0, 1, 1)
        ax = mat.heatmap(self.df_norm, linewidths=2, linecolor=blue)

        mesh = ax.collections[0]
        first_width = mesh.get_linewidths()[0]
        first_edge = tuple(mesh.get_edgecolor()[0])

        assert first_width == 2
        assert first_edge == blue

    def test_square_aspect(self):
        """``square=True`` sets the axes aspect to 1."""
        aspect = mat.heatmap(self.df_norm, square=True).get_aspect()
        npt.assert_equal(aspect, 1)

    def test_mask_validation(self):
        """``_matrix_mask`` builds a default mask and rejects mismatched ones."""
        # No mask given: an all-False mask with the shape of the data.
        mask = mat._matrix_mask(self.df_norm, None)
        assert mask.shape == self.df_norm.shape
        assert mask.values.sum() == 0

        # The bad masks are built outside the ``raises`` blocks so that only
        # a ValueError coming from ``_matrix_mask`` itself satisfies the test.

        # Array mask with the wrong shape.
        bad_array_mask = self.rs.randn(3, 6) > 0
        with pytest.raises(ValueError):
            mat._matrix_mask(self.df_norm, bad_array_mask)

        # Frame mask with the right shape but a default integer index rather
        # than the letter index of ``df_norm``.
        bad_df_mask = pd.DataFrame(self.rs.randn(4, 8) > 0)
        with pytest.raises(ValueError):
            mat._matrix_mask(self.df_norm, bad_df_mask)

    def test_missing_data_mask(self):
        """Missing cells are always masked, on top of any user mask."""
        frame = pd.DataFrame(np.arange(4, dtype=float).reshape(2, 2))
        frame.loc[0, 0] = np.nan

        # No user mask: only the NaN cell is masked.
        auto_mask = mat._matrix_mask(frame, None)
        npt.assert_array_equal(auto_mask, [[True, False], [False, False]])

        # User mask on another cell: both cells end up masked.
        user_mask = np.array([[False, True], [False, False]])
        combined = mat._matrix_mask(frame, user_mask)
        npt.assert_array_equal(combined, [[True, True], [False, False]])

    def test_cbar_ticks(self):
        """``cbar_kws`` reach the colorbar (``drawedges`` adds a collection)."""
        _, (main_ax, cbar_ax) = plt.subplots(2)
        colorbar_options = dict(drawedges=True)
        mat.heatmap(self.df_norm, ax=main_ax, cbar_ax=cbar_ax,
                    cbar_kws=colorbar_options)
        assert len(cbar_ax.collections) == 2


@pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
class TestDendrogram:

    # Dedicated RNG seeded from the class topic so fixtures are reproducible.
    rs = np.random.RandomState(sum(map(ord, "dendrogram")))

    # Keyword arguments handed to ``mat._DendrogramPlotter`` by most tests.
    default_kws = dict(linkage=None, metric='euclidean', method='single',
                       axis=1, label=True, rotate=False)

    # 4 x 8 normal draws with a per-column offset (0..7) and then a per-row
    # offset (0..3) added on top.
    x_norm = rs.randn(4, 8) + np.arange(8)
    x_norm = (x_norm.T + np.arange(4)).T
    letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"],
                        name="letters")

    df_norm = pd.DataFrame(x_norm, columns=letters)

    # The class body runs at import time even when the class is skipped, so
    # the scipy-dependent reference values are only computed if scipy loaded.
    if not _no_scipy:
        # Reference linkage over the columns (``x_norm.T``), computed with
        # fastcluster when it is installed and with scipy otherwise.
        if _no_fastcluster:
            x_norm_distances = distance.pdist(x_norm.T, metric='euclidean')
            x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single')
        else:
            x_norm_linkage = fastcluster.linkage_vector(x_norm.T,
                                                        metric='euclidean',
                                                        method='single')

        # Reference dendrogram and leaf order, as positions and as labels.
        x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True,
                                                 color_threshold=-np.inf)
        x_norm_leaves = x_norm_dendrogram['leaves']
        df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves])

    def test_ndarray_input(self):
        """Array input is clustered over columns with positional labels."""
        plotter = mat._DendrogramPlotter(self.x_norm, **self.default_kws)

        # With ``axis=1`` the plotter stores the transposed input.
        npt.assert_array_equal(plotter.array.T, self.x_norm)
        pdt.assert_frame_equal(plotter.data.T, pd.DataFrame(self.x_norm))

        npt.assert_array_equal(plotter.linkage, self.x_norm_linkage)
        assert plotter.dendrogram == self.x_norm_dendrogram
        npt.assert_array_equal(plotter.reordered_ind, self.x_norm_leaves)

        npt.assert_array_equal(plotter.xticklabels, self.x_norm_leaves)
        npt.assert_array_equal(plotter.yticklabels, [])

        assert plotter.xlabel is None
        assert plotter.ylabel == ''

    def test_df_input(self):
        """Frame input is labelled with its column names in leaf order."""
        plotter = mat._DendrogramPlotter(self.df_norm, **self.default_kws)

        npt.assert_array_equal(plotter.array.T, np.asarray(self.df_norm))
        pdt.assert_frame_equal(plotter.data.T, self.df_norm)

        npt.assert_array_equal(plotter.linkage, self.x_norm_linkage)
        assert plotter.dendrogram == self.x_norm_dendrogram

        ordered_columns = np.asarray(self.df_norm.columns)[self.x_norm_leaves]
        npt.assert_array_equal(plotter.xticklabels, ordered_columns)
        npt.assert_array_equal(plotter.yticklabels, [])

        assert plotter.xlabel == 'letters'
        assert plotter.ylabel == ''

    def test_df_multindex_input(self):
        """MultiIndex levels are joined with "-" for the leaf labels."""
        pairs = [("A", 1), ("B", 2), ("C", 3), ("D", 4)]
        multi = pd.MultiIndex.from_tuples(pairs, names=["letter", "number"])
        multi.name = "letter-number"

        frame = self.df_norm.copy()
        frame.index = multi

        kws = dict(self.default_kws, label=True)
        # Transposed so the MultiIndex sits on the clustered (column) axis.
        plotter = mat._DendrogramPlotter(frame.T, **kws)

        joined = ["A-1", "B-2", "C-3", "D-4"]
        expected = [joined[leaf] for leaf in plotter.reordered_ind]
        npt.assert_array_equal(plotter.xticklabels, expected)
        npt.assert_array_equal(plotter.yticklabels, [])
        assert plotter.xlabel == "letter-number"

    def test_axis0_input(self):
        """``axis=0`` clusters the rows and stores the input untransposed."""
        transposed = self.df_norm.T
        kws = dict(self.default_kws, axis=0)
        plotter = mat._DendrogramPlotter(transposed, **kws)

        npt.assert_array_equal(plotter.array, np.asarray(transposed))
        pdt.assert_frame_equal(plotter.data, transposed)

        npt.assert_array_equal(plotter.linkage, self.x_norm_linkage)
        assert plotter.dendrogram == self.x_norm_dendrogram

        npt.assert_array_equal(plotter.xticklabels, self.df_norm_leaves)
        npt.assert_array_equal(plotter.yticklabels, [])

        assert plotter.xlabel == 'letters'
        assert plotter.ylabel == ''

    def test_rotate_input(self):
        """``rotate=True`` moves the leaf labels and axis label to the y axis."""
        kws = dict(self.default_kws, rotate=True)
        plotter = mat._DendrogramPlotter(self.df_norm, **kws)

        npt.assert_array_equal(plotter.array.T, np.asarray(self.df_norm))
        pdt.assert_frame_equal(plotter.data.T, self.df_norm)

        npt.assert_array_equal(plotter.xticklabels, [])
        npt.assert_array_equal(plotter.yticklabels, self.df_norm_leaves)

        assert plotter.xlabel == ''
        assert plotter.ylabel == 'letters'

    def test_rotate_axis0_input(self):
        """Rotation does not change the leaf order when clustering rows."""
        kws = dict(self.default_kws, rotate=True, axis=0)
        plotter = mat._DendrogramPlotter(self.df_norm.T, **kws)
        npt.assert_array_equal(plotter.reordered_ind, self.x_norm_leaves)

    def test_custom_linkage(self):
        """A precomputed linkage is used as given instead of being recomputed.

        The linkage is computed over the columns (``x_norm.T``): with the
        default ``axis=1`` the plotter clusters the eight columns of
        ``df_norm``, so a linkage over the four rows would describe a
        different number of leaves than the data being plotted.
        """
        kws = self.default_kws.copy()

        try:
            import fastcluster
        except ImportError:
            # Fall back to scipy when fastcluster is not installed.
            d = distance.pdist(self.x_norm.T, metric='euclidean')
            linkage = hierarchy.linkage(d, method='single')
        else:
            linkage = fastcluster.linkage_vector(self.x_norm.T,
                                                 method='single',
                                                 metric='euclidean')

        # Same dendrogram options as the class-level reference dendrogram.
        dendrogram = hierarchy.dendrogram(linkage, no_plot=True,
                                          color_threshold=-np.inf)
        kws['linkage'] = linkage
        p = mat._DendrogramPlotter(self.df_norm, **kws)

        npt.assert_array_equal(p.linkage, linkage)
        assert p.dendrogram == dendrogram

    def test_label_false(self):
        """``label=False`` leaves ticks, tick labels and axis labels empty."""
        kws = dict(self.default_kws, label=False)
        plotter = mat._DendrogramPlotter(self.df_norm, **kws)

        assert plotter.xticks == []
        assert plotter.yticks == []
        assert plotter.xticklabels == []
        assert plotter.yticklabels == []
        assert plotter.xlabel == ""
        assert plotter.ylabel == ""

    def test_linkage_scipy(self):
        """The scipy code path matches a direct pdist + linkage computation."""
        from scipy.spatial import distance
        from scipy.cluster import hierarchy

        plotter = mat._DendrogramPlotter(self.x_norm, **self.default_kws)
        computed = plotter._calculate_linkage_scipy()

        metric = self.default_kws['metric']
        method = self.default_kws['method']
        pairwise = distance.pdist(self.x_norm.T, metric=metric)
        expected = hierarchy.linkage(pairwise, method=method)

        npt.assert_array_equal(computed, expected)

    @pytest.mark.skipif(_no_fastcluster, reason="fastcluster not installed")
    def test_fastcluster_other_method(self):
        """With ``method='average'`` the result matches ``fastcluster.linkage``."""
        import fastcluster

        expected = fastcluster.linkage(self.x_norm.T, method='average',
                                       metric='euclidean')

        kws = dict(self.default_kws, method='average')
        plotter = mat._DendrogramPlotter(self.x_norm, **kws)
        npt.assert_array_equal(plotter.linkage, expected)

    @pytest.mark.skipif(_no_fastcluster, reason="fastcluster not installed")
    def test_fastcluster_non_euclidean(self):
        """With a cosine metric the result matches ``fastcluster.linkage``."""
        import fastcluster

        kws = dict(self.default_kws, metric='cosine', method='average')
        expected = fastcluster.linkage(self.x_norm.T, method=kws['method'],
                                       metric=kws['metric'])

        plotter = mat._DendrogramPlotter(self.x_norm, **kws)
        npt.assert_array_equal(plotter.linkage, expected)

    def test_dendrogram_plot(self):
        """The x limits span all leaves and every branch gets a path."""
        plotter = mat.dendrogram(self.x_norm, **self.default_kws)
        ax = plt.gca()

        left, right = ax.get_xlim()
        # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy
        expected_right = len(plotter.reordered_ind) * 10

        assert left == 0
        assert right == expected_right

        drawn_paths = ax.collections[0].get_paths()
        assert len(drawn_paths) == len(plotter.dependent_coord)

    def test_dendrogram_rotate(self):
        """A rotated dendrogram spans the leaves on an inverted y axis."""
        kws = dict(self.default_kws, rotate=True)
        plotter = mat.dendrogram(self.x_norm, **kws)

        first, second = plt.gca().get_ylim()

        # 10 comes from _plot_dendrogram in scipy.cluster.hierarchy
        expected_extent = len(plotter.reordered_ind) * 10

        # Since y axis is inverted, ylim is (80, 0)
        # and therefore not (0, 80) as usual:
        assert second == 0
        assert first == expected_extent

    def test_dendrogram_ticklabel_rotation(self):
        """Leaf labels are only rotated when they would otherwise be too long.

        Every scenario draws into an explicitly created axes and passes it
        through ``ax=``, so the assertions inspect the axes that was drawn on
        rather than relying on matplotlib's implicit current axes.
        """
        # Short labels stay horizontal.
        f, ax = plt.subplots(figsize=(2, 2))
        mat.dendrogram(self.df_norm, ax=ax)

        for t in ax.get_xticklabels():
            assert t.get_rotation() == 0

        plt.close(f)

        # Ten-fold repeated column labels are turned vertical.
        df = self.df_norm.copy()
        df.columns = [str(c) * 10 for c in df.columns]
        df.index = [i * 10 for i in df.index]

        f, ax = plt.subplots(figsize=(2, 2))
        mat.dendrogram(df, ax=ax)

        for t in ax.get_xticklabels():
            assert t.get_rotation() == 90

        plt.close(f)

        # Rotated dendrogram: the long labels sit on the y axis, unrotated.
        f, ax = plt.subplots(figsize=(2, 2))
        mat.dendrogram(df.T, axis=0, rotate=True, ax=ax)
        for t in ax.get_yticklabels():
            assert t.get_rotation() == 0
        plt.close(f)


@pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
class TestClustermap:

    # Dedicated RNG seeded from the class topic so fixtures are reproducible.
    rs = np.random.RandomState(sum(map(ord, "clustermap")))

    # 4 x 8 normal draws with a per-column offset (0..7) and then a per-row
    # offset (0..3) added on top.
    x_norm = rs.randn(4, 8) + np.arange(8)
    x_norm = (x_norm.T + np.arange(4)).T
    letters = pd.Series(["A", "B", "C", "D", "E", "F", "G", "H"],
                        name="letters")

    df_norm = pd.DataFrame(x_norm, columns=letters)

    # Keyword arguments the tests pass when building a grid, either to
    # ``mat.ClusterGrid`` or to ``mat.clustermap``.
    default_kws = dict(pivot_kws=None, z_score=None, standard_scale=None,
                       figsize=(10, 10), row_colors=None, col_colors=None,
                       dendrogram_ratio=.2, colors_ratio=.03,
                       cbar_pos=(0, .8, .05, .2))

    # Keyword arguments the tests pass to ``ClusterGrid.plot``.
    default_plot_kws = dict(metric='euclidean', method='average',
                            colorbar_kws=None,
                            row_cluster=True, col_cluster=True,
                            row_linkage=None, col_linkage=None,
                            tree_kws=None)

    # One color per row / per column of ``df_norm``.
    row_colors = color_palette('Set2', df_norm.shape[0])
    col_colors = color_palette('Dark2', df_norm.shape[1])

    # The class body runs at import time even when the class is skipped, so
    # the scipy-dependent reference values are only computed if scipy loaded.
    if not _no_scipy:
        # Reference linkage over the columns (``x_norm.T``), computed with
        # fastcluster when it is installed and with scipy otherwise.
        if _no_fastcluster:
            x_norm_distances = distance.pdist(x_norm.T, metric='euclidean')
            x_norm_linkage = hierarchy.linkage(x_norm_distances, method='single')
        else:
            x_norm_linkage = fastcluster.linkage_vector(x_norm.T,
                                                        metric='euclidean',
                                                        method='single')

        # Reference dendrogram and leaf order, as positions and as labels.
        x_norm_dendrogram = hierarchy.dendrogram(x_norm_linkage, no_plot=True,
                                                 color_threshold=-np.inf)
        x_norm_leaves = x_norm_dendrogram['leaves']
        df_norm_leaves = np.asarray(df_norm.columns[x_norm_leaves])

    def test_ndarray_input(self):
        """Array input gives a four-axes grid without color-annotation axes."""
        grid = mat.ClusterGrid(self.x_norm, **self.default_kws)

        pdt.assert_frame_equal(grid.data, pd.DataFrame(self.x_norm))
        assert len(grid.fig.axes) == 4

        # No row/column colors were requested, so those axes do not exist.
        assert grid.ax_row_colors is None
        assert grid.ax_col_colors is None

    def test_df_input(self):
        """Frame input is stored on the grid unchanged."""
        grid = mat.ClusterGrid(self.df_norm, **self.default_kws)
        pdt.assert_frame_equal(grid.data, self.df_norm)

    def test_corr_df_input(self):
        """Plotting a correlation matrix leaves its unit diagonal in place."""
        corr = self.df_norm.corr()
        grid = mat.ClusterGrid(corr, **self.default_kws)
        grid.plot(**self.default_plot_kws)

        plotted = grid.data2d
        diagonal = plotted.values[np.diag_indices_from(plotted)]
        npt.assert_array_almost_equal(diagonal, np.ones(plotted.shape[0]))

    def test_pivot_input(self):
        """Long-form data is pivoted back to wide form via ``pivot_kws``."""
        wide = self.df_norm.copy()
        wide.index.name = 'numbers'

        long_form = pd.melt(wide.reset_index(), var_name='letters',
                            id_vars='numbers')

        pivot_spec = dict(index='numbers', columns='letters', values='value')
        kws = dict(self.default_kws, pivot_kws=pivot_spec)
        grid = mat.ClusterGrid(long_form, **kws)

        pdt.assert_frame_equal(grid.data2d, wide)

    def test_colors_input(self):
        """Row/column colors are stored and add two axes to the figure."""
        kws = dict(self.default_kws,
                   row_colors=self.row_colors, col_colors=self.col_colors)
        grid = mat.ClusterGrid(self.df_norm, **kws)

        npt.assert_array_equal(grid.row_colors, self.row_colors)
        npt.assert_array_equal(grid.col_colors, self.col_colors)

        assert len(grid.fig.axes) == 6

    def test_categorical_colors_input(self):
        """Categorical Series of colors are converted to RGB tuples."""
        row_series = pd.Series(self.row_colors, dtype="category")
        col_series = pd.Series(
            self.col_colors, dtype="category", index=self.df_norm.columns
        )

        expected_rows = [mpl.colors.to_rgb(color) for color in row_series]
        expected_cols = [mpl.colors.to_rgb(color) for color in col_series]

        kws = dict(self.default_kws,
                   row_colors=row_series, col_colors=col_series)
        grid = mat.ClusterGrid(self.df_norm, **kws)

        npt.assert_array_equal(grid.row_colors, expected_rows)
        npt.assert_array_equal(grid.col_colors, expected_cols)

        assert len(grid.fig.axes) == 6

    def test_nested_colors_input(self):
        """Several levels of row/column colors are stored as given."""
        nested_rows = [self.row_colors, self.row_colors]
        nested_cols = [self.col_colors, self.col_colors]

        kws = dict(self.default_kws,
                   row_colors=nested_rows, col_colors=nested_cols)
        grid = mat.ClusterGrid(self.df_norm, **kws)

        npt.assert_array_equal(grid.row_colors, nested_rows)
        npt.assert_array_equal(grid.col_colors, nested_cols)

        # Nested colors share one axes per side, so the count is still six.
        assert len(grid.fig.axes) == 6

    def test_colors_input_custom_cmap(self):
        """A custom heatmap ``cmap`` does not alter the row/column colors."""
        kws = dict(self.default_kws, cmap=mpl.cm.PRGn,
                   row_colors=self.row_colors, col_colors=self.col_colors)
        grid = mat.clustermap(self.df_norm, **kws)

        npt.assert_array_equal(grid.row_colors, self.row_colors)
        npt.assert_array_equal(grid.col_colors, self.col_colors)

        assert len(grid.fig.axes) == 6

    def test_z_score(self):
        """``z_score=1`` standardizes each column."""
        frame = self.df_norm.copy()
        expected = (frame - frame.mean()) / frame.std()

        kws = dict(self.default_kws, z_score=1)
        grid = mat.ClusterGrid(self.df_norm, **kws)
        pdt.assert_frame_equal(grid.data2d, expected)

    def test_z_score_axis0(self):
        """``z_score=0`` standardizes each row."""
        # Standardize rows by working on the transpose and flipping back.
        flipped = self.df_norm.copy().T
        flipped = (flipped - flipped.mean()) / flipped.std()
        expected = flipped.T

        kws = dict(self.default_kws, z_score=0)
        grid = mat.ClusterGrid(self.df_norm, **kws)
        pdt.assert_frame_equal(grid.data2d, expected)

    def test_standard_scale(self):
        """``standard_scale=1`` rescales each column to the unit interval."""
        frame = self.df_norm.copy()
        lowest, highest = frame.min(), frame.max()
        expected = (frame - lowest) / (highest - lowest)

        kws = dict(self.default_kws, standard_scale=1)
        grid = mat.ClusterGrid(self.df_norm, **kws)
        pdt.assert_frame_equal(grid.data2d, expected)

    def test_standard_scale_axis0(self):
        """``standard_scale=0`` rescales each row to the unit interval."""
        # Rescale rows by working on the transpose and flipping back.
        flipped = self.df_norm.copy().T
        lowest, highest = flipped.min(), flipped.max()
        expected = ((flipped - lowest) / (highest - lowest)).T

        kws = dict(self.default_kws, standard_scale=0)
        grid = mat.ClusterGrid(self.df_norm, **kws)
        pdt.assert_frame_equal(grid.data2d, expected)

    def test_z_score_standard_scale(self):
        """Requesting both normalizations at once is rejected."""
        kws = dict(self.default_kws, z_score=True, standard_scale=True)
        with pytest.raises(ValueError):
            mat.ClusterGrid(self.df_norm, **kws)

    def test_color_list_to_matrix_and_cmap(self):
        """With ``axis=0`` the matrix is a column of colors in leaf order."""
        # Note this uses the attribute named col_colors but tests row colors
        palette = self.col_colors
        leaves = self.x_norm_leaves
        matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
            palette, leaves, axis=0)

        for row, leaf in enumerate(leaves):
            assert_colors_equal(cmap(matrix[row, 0]), palette[leaf])

    def test_nested_color_list_to_matrix_and_cmap(self):
        """Each nested color list becomes one matrix column in leaf order."""
        # Note this uses the attribute named col_colors but tests row colors
        levels = [self.col_colors, self.col_colors[::-1]]
        leaves = self.x_norm_leaves
        matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
            levels, leaves, axis=0)

        for row, leaf in enumerate(leaves):
            for col, level in enumerate(levels):
                assert_colors_equal(cmap(matrix[row, col]), level[leaf])

    def test_color_list_to_matrix_and_cmap_axis1(self):
        """With ``axis=1`` the matrix is a row of colors in leaf order."""
        palette = self.col_colors
        leaves = self.x_norm_leaves
        matrix, cmap = mat.ClusterGrid.color_list_to_matrix_and_cmap(
            palette, leaves, axis=1)

        for col, leaf in enumerate(leaves):
            assert_colors_equal(cmap(matrix[0, col]), palette[leaf])

    def test_color_list_to_matrix_and_cmap_different_sizes(self):
        """Nested color lists of unequal length are rejected."""
        doubled = self.col_colors * 2
        uneven = [self.col_colors, doubled]
        with pytest.raises(ValueError):
            mat.ClusterGrid.color_list_to_matrix_and_cmap(
                uneven, self.x_norm_leaves, axis=1)

    def test_savefig(self):
        """A plotted grid can be written as PNG to an open file object.

        Smoke test only: it passes as long as ``savefig`` does not raise.
        """
        cg = mat.ClusterGrid(self.df_norm, **self.default_kws)
        cg.plot(**self.default_plot_kws)

        # The context manager closes (and thereby removes) the temporary file
        # even if ``savefig`` raises, instead of leaking the open handle.
        with tempfile.NamedTemporaryFile() as tmp:
            cg.savefig(tmp, format='png')

    def test_plot_dendrograms(self):
        """Drawn trees match the dendrogram coords; data follow leaf order."""
        cm = mat.clustermap(self.df_norm, **self.default_kws)

        axes_and_trees = [
            (cm.ax_row_dendrogram, cm.dendrogram_row),
            (cm.ax_col_dendrogram, cm.dendrogram_col),
        ]
        for ax, tree in axes_and_trees:
            n_paths = len(ax.collections[0].get_paths())
            assert n_paths == len(tree.independent_coord)

        row_order = cm.dendrogram_row.reordered_ind
        col_order = cm.dendrogram_col.reordered_ind
        expected = self.df_norm.iloc[row_order, col_order]
        pdt.assert_frame_equal(cm.data2d, expected)

    def test_cluster_false(self):
        """Without clustering the dendrogram axes are bare and data unchanged."""
        kws = dict(self.default_kws, row_cluster=False, col_cluster=False)

        cm = mat.clustermap(self.df_norm, **kws)
        dendrogram_axes = (cm.ax_row_dendrogram, cm.ax_col_dendrogram)

        for ax in dendrogram_axes:
            assert len(ax.lines) == 0

        for ax in dendrogram_axes:
            assert len(ax.get_xticks()) == 0
            assert len(ax.get_yticks()) == 0

        # The data must keep their original order when nothing is clustered
        pdt.assert_frame_equal(cm.data2d, self.df_norm)

    def test_row_col_colors(self):
        """Row and column color axes each hold exactly one collection."""
        kws = dict(
            self.default_kws,
            row_colors=self.row_colors,
            col_colors=self.col_colors,
        )

        cm = mat.clustermap(self.df_norm, **kws)

        for ax in (cm.ax_row_colors, cm.ax_col_colors):
            assert len(ax.collections) == 1

    def test_cluster_false_row_col_colors(self):
        """Color annotations are still drawn when clustering is disabled."""
        kws = dict(
            self.default_kws,
            row_cluster=False,
            col_cluster=False,
            row_colors=self.row_colors,
            col_colors=self.col_colors,
        )

        cm = mat.clustermap(self.df_norm, **kws)
        dendrogram_axes = (cm.ax_row_dendrogram, cm.ax_col_dendrogram)

        for ax in dendrogram_axes:
            assert len(ax.lines) == 0

        for ax in dendrogram_axes:
            assert len(ax.get_xticks()) == 0
            assert len(ax.get_yticks()) == 0

        for ax in (cm.ax_row_colors, cm.ax_col_colors):
            assert len(ax.collections) == 1

        # The data must keep their original order when nothing is clustered
        pdt.assert_frame_equal(cm.data2d, self.df_norm)

    def test_row_col_colors_df(self):
        """DataFrame color annotations label the color axes by column name."""
        row_names = ['row_1', 'row_2']
        col_names = ['col_1', 'col_2']

        kws = self.default_kws.copy()
        kws['row_colors'] = pd.DataFrame(
            {name: list(self.row_colors) for name in row_names},
            index=self.df_norm.index,
            columns=['row_1', 'row_2'],
        )
        kws['col_colors'] = pd.DataFrame(
            {name: list(self.col_colors) for name in col_names},
            index=self.df_norm.columns,
            columns=['col_1', 'col_2'],
        )

        cm = mat.clustermap(self.df_norm, **kws)

        # Row annotations are labelled along the x axis of the row-color axes
        drawn_row_labels = [
            tick.get_text() for tick in cm.ax_row_colors.get_xticklabels()
        ]
        assert cm.row_color_labels == ['row_1', 'row_2']
        assert drawn_row_labels == cm.row_color_labels

        # Column annotations are labelled along the y axis of the col-color axes
        drawn_col_labels = [
            tick.get_text() for tick in cm.ax_col_colors.get_yticklabels()
        ]
        assert cm.col_color_labels == ['col_1', 'col_2']
        assert drawn_col_labels == cm.col_color_labels

    def test_row_col_colors_df_shuffled(self):
        """DataFrame colors given in shuffled order are realigned to the data."""
        n_rows, n_cols = self.df_norm.shape

        # Shuffle deterministically: even positions first, then odd positions
        row_positions = [*range(0, n_rows, 2), *range(1, n_rows, 2)]
        col_positions = [*range(0, n_cols, 2), *range(1, n_cols, 2)]
        shuffled_rows = [self.df_norm.index[pos] for pos in row_positions]
        shuffled_cols = [self.df_norm.columns[pos] for pos in col_positions]

        row_annot = pd.DataFrame({'row_annot': list(self.row_colors)},
                                 index=self.df_norm.index)
        col_annot = pd.DataFrame({'col_annot': list(self.col_colors)},
                                 index=self.df_norm.columns)

        kws = dict(
            self.default_kws,
            row_colors=row_annot.loc[shuffled_rows],
            col_colors=col_annot.loc[shuffled_cols],
        )

        cm = mat.clustermap(self.df_norm, **kws)

        first_col_level = list(cm.col_colors)[0]
        assert first_col_level == list(self.col_colors)
        first_row_level = list(cm.row_colors)[0]
        assert first_row_level == list(self.row_colors)

    def test_row_col_colors_df_missing(self):
        """Entries absent from a color DataFrame are filled with white."""
        white = (1.0, 1.0, 1.0)

        row_annot = pd.DataFrame({'row_annot': list(self.row_colors)},
                                 index=self.df_norm.index)
        col_annot = pd.DataFrame({'col_annot': list(self.col_colors)},
                                 index=self.df_norm.columns)

        # Drop the first row / column label from the annotations
        kws = dict(
            self.default_kws,
            row_colors=row_annot.drop(self.df_norm.index[0]),
            col_colors=col_annot.drop(self.df_norm.columns[0]),
        )

        cm = mat.clustermap(self.df_norm, **kws)

        first_col_level = list(cm.col_colors)[0]
        assert first_col_level == [white, *self.col_colors[1:]]
        first_row_level = list(cm.row_colors)[0]
        assert first_row_level == [white, *self.row_colors[1:]]

    def test_row_col_colors_df_one_axis(self):
        """DataFrame color labels work when only one axis is annotated."""
        # Case 1: rows annotated, columns not.
        row_names = ['row_1', 'row_2']
        row_annot = pd.DataFrame(
            {name: list(self.row_colors) for name in row_names},
            index=self.df_norm.index,
            columns=['row_1', 'row_2'],
        )
        rows_only = mat.clustermap(
            self.df_norm, **dict(self.default_kws, row_colors=row_annot)
        )

        drawn_row_labels = [
            tick.get_text()
            for tick in rows_only.ax_row_colors.get_xticklabels()
        ]
        assert rows_only.row_color_labels == ['row_1', 'row_2']
        assert drawn_row_labels == rows_only.row_color_labels

        # Case 2: columns annotated, rows not.
        col_names = ['col_1', 'col_2']
        col_annot = pd.DataFrame(
            {name: list(self.col_colors) for name in col_names},
            index=self.df_norm.columns,
            columns=['col_1', 'col_2'],
        )
        cols_only = mat.clustermap(
            self.df_norm, **dict(self.default_kws, col_colors=col_annot)
        )

        drawn_col_labels = [
            tick.get_text()
            for tick in cols_only.ax_col_colors.get_yticklabels()
        ]
        assert cols_only.col_color_labels == ['col_1', 'col_2']
        assert drawn_col_labels == cols_only.col_color_labels

    def test_row_col_colors_series(self):
        """Series color annotations label the color axes by the Series name."""
        row_annot = pd.Series(list(self.row_colors), name='row_annot',
                              index=self.df_norm.index)
        col_annot = pd.Series(list(self.col_colors), name='col_annot',
                              index=self.df_norm.columns)
        kws = dict(self.default_kws, row_colors=row_annot, col_colors=col_annot)

        cm = mat.clustermap(self.df_norm, **kws)

        drawn_row_labels = [
            tick.get_text() for tick in cm.ax_row_colors.get_xticklabels()
        ]
        assert cm.row_color_labels == ['row_annot']
        assert drawn_row_labels == cm.row_color_labels

        drawn_col_labels = [
            tick.get_text() for tick in cm.ax_col_colors.get_yticklabels()
        ]
        assert cm.col_color_labels == ['col_annot']
        assert drawn_col_labels == cm.col_color_labels

    def test_row_col_colors_series_shuffled(self):
        """Series colors given in shuffled order are realigned to the data."""
        n_rows, n_cols = self.df_norm.shape

        # Shuffle deterministically: even positions first, then odd positions
        row_positions = [*range(0, n_rows, 2), *range(1, n_rows, 2)]
        col_positions = [*range(0, n_cols, 2), *range(1, n_cols, 2)]
        shuffled_rows = [self.df_norm.index[pos] for pos in row_positions]
        shuffled_cols = [self.df_norm.columns[pos] for pos in col_positions]

        row_annot = pd.Series(list(self.row_colors), name='row_annot',
                              index=self.df_norm.index)
        col_annot = pd.Series(list(self.col_colors), name='col_annot',
                              index=self.df_norm.columns)

        kws = dict(
            self.default_kws,
            row_colors=row_annot.loc[shuffled_rows],
            col_colors=col_annot.loc[shuffled_cols],
        )

        cm = mat.clustermap(self.df_norm, **kws)

        assert list(cm.col_colors) == list(self.col_colors)
        assert list(cm.row_colors) == list(self.row_colors)

    def test_row_col_colors_series_missing(self):
        """Entries absent from a color Series are filled with white."""
        white = (1.0, 1.0, 1.0)

        row_annot = pd.Series(list(self.row_colors), name='row_annot',
                              index=self.df_norm.index)
        col_annot = pd.Series(list(self.col_colors), name='col_annot',
                              index=self.df_norm.columns)

        # Drop the first row / column label from the annotations
        kws = dict(
            self.default_kws,
            row_colors=row_annot.drop(self.df_norm.index[0]),
            col_colors=col_annot.drop(self.df_norm.columns[0]),
        )

        cm = mat.clustermap(self.df_norm, **kws)
        assert list(cm.col_colors) == [white, *self.col_colors[1:]]
        assert list(cm.row_colors) == [white, *self.row_colors[1:]]

    def test_row_col_colors_ignore_heatmap_kwargs(self):
        """Heatmap cmap/norm/vmax settings do not alter the annotation colors."""
        data = self.rs.uniform(0, 200, self.df_norm.shape)
        g = mat.clustermap(
            data,
            row_colors=self.row_colors,
            col_colors=self.col_colors,
            cmap="Spectral",
            norm=mpl.colors.LogNorm(),
            vmax=100,
        )

        checks = [
            (self.row_colors, g.dendrogram_row, g.ax_row_colors),
            (self.col_colors, g.dendrogram_col, g.ax_col_colors),
        ]
        for colors, tree, ax in checks:
            # Input colors in dendrogram order vs. the RGB part of what was drawn
            want = np.array(colors)[tree.reordered_ind]
            got = ax.collections[0].get_facecolors()[:, :3]
            assert np.array_equal(want, got)

    def test_row_col_colors_raise_on_mixed_index_types(self):
        """Indexed color Series combined with ``self.x_norm`` raise TypeError."""
        row_annot = pd.Series(
            list(self.row_colors), name="row_annot", index=self.df_norm.index
        )
        col_annot = pd.Series(
            list(self.col_colors), name="col_annot", index=self.df_norm.columns
        )

        cases = [("row_colors", row_annot), ("col_colors", col_annot)]
        for param, colors in cases:
            with pytest.raises(TypeError):
                mat.clustermap(self.x_norm, **{param: colors})

    def test_mask_reorganization(self):
        """The mask is reordered together with the clustered data."""
        kws = dict(self.default_kws, mask=self.df_norm > 0)

        g = mat.clustermap(self.df_norm, **kws)

        # Mask and data must share the same labels after reordering ...
        npt.assert_array_equal(g.data2d.index, g.mask.index)
        npt.assert_array_equal(g.data2d.columns, g.mask.columns)

        # ... and those labels must follow the dendrogram leaf order.
        row_order = g.dendrogram_row.reordered_ind
        col_order = g.dendrogram_col.reordered_ind
        npt.assert_array_equal(g.mask.index, self.df_norm.index[row_order])
        npt.assert_array_equal(g.mask.columns, self.df_norm.columns[col_order])

    def test_ticklabel_reorganization(self):

        kws = self.default_kws.copy()
        xtl = np.arange(self.df_norm.shape[1])
        kws["xticklabels"] = list(xtl)
        ytl = self.letters.loc[:self.df_norm.shape[0]]
        kws["yticklabels"] = ytl

        g = mat.clustermap(self.df_norm, **kws)

        xtl_actual = [t.get_text() for t in g.ax_heatmap.get_xticklabels()]
        ytl_actual = [t.get_text() for t in g.ax_heatmap.get_yticklabels()]

        xtl_want = xtl[g.dendrogram_col.reordered_ind].astype(" g1.ax_col_dendrogram.get_position().height)

        assert (g2.ax_col_colors.get_position().height
                > g1.ax_col_colors.get_position().height)

        assert (g2.ax_heatmap.get_position().height
                < g1.ax_heatmap.get_position().height)

        assert (g2.ax_row_dendrogram.get_position().width
                > g1.ax_row_dendrogram.get_position().width)

        assert (g2.ax_row_colors.get_position().width
                > g1.ax_row_colors.get_position().width)

        assert (g2.ax_heatmap.get_position().width
                < g1.ax_heatmap.get_position().width)

        kws1 = self.default_kws.copy()
        kws1.update(col_colors=self.col_colors)
        kws2 = kws1.copy()
        kws2.update(col_colors=[self.col_colors, self.col_colors])

        g1 = mat.clustermap(self.df_norm, **kws1)
        g2 = mat.clustermap(self.df_norm, **kws2)

        assert (g2.ax_col_colors.get_position().height
                > g1.ax_col_colors.get_position().height)

        kws1 = self.default_kws.copy()
        kws1.update(dendrogram_ratio=(.2, .2))

        kws2 = kws1.copy()
        kws2.update(dendrogram_ratio=(.2, .3))

        g1 = mat.clustermap(self.df_norm, **kws1)
        g2 = mat.clustermap(self.df_norm, **kws2)

        # Fails on pinned matplotlib?
        # assert (g2.ax_row_dendrogram.get_position().width
        #         == g1.ax_row_dendrogram.get_position().width)
        assert g1.gs.get_width_ratios() == g2.gs.get_width_ratios()

        assert (g2.ax_col_dendrogram.get_position().height
                > g1.ax_col_dendrogram.get_position().height)

    def test_cbar_pos(self):
        """``cbar_pos`` places the colorbar axes, and ``None`` omits them."""
        requested = (.2, .1, .4, .3)
        kws = dict(self.default_kws, cbar_pos=requested)

        g = mat.clustermap(self.df_norm, **kws)
        actual = g.ax_cbar.get_position()
        origin, (width, height) = requested[:2], requested[2:]
        assert pytest.approx(tuple(actual.p0)) == origin
        assert pytest.approx(actual.width) == width
        assert pytest.approx(actual.height) == height

        kws["cbar_pos"] = None
        g = mat.clustermap(self.df_norm, **kws)
        assert g.ax_cbar is None

    def test_square_warning(self):
        """``square=True`` warns and leaves the heatmap position unchanged."""
        baseline = mat.clustermap(self.df_norm, **self.default_kws.copy())

        square_kws = dict(self.default_kws, square=True)
        with pytest.warns(UserWarning):
            squared = mat.clustermap(self.df_norm, **square_kws)

        baseline_box = baseline.ax_heatmap.get_position().get_points()
        squared_box = squared.ax_heatmap.get_position().get_points()
        assert np.array_equal(baseline_box, squared_box)

    def test_clustermap_annotation(self):
        """Annotation texts follow the reordered data, for both annot forms."""
        # ``annot`` is tried first as a flag, then as an explicit data frame
        for annot in (True, self.df_norm):
            g = mat.clustermap(self.df_norm, annot=annot, fmt=".1f")
            values = np.asarray(g.data2d).flat
            for val, text in zip(values, g.ax_heatmap.texts):
                assert text.get_text() == f"{val:.1f}"

    def test_tree_kws(self):
        """A color passed through ``tree_kws`` is applied to both dendrograms."""
        rgb = (1, .5, .2)
        g = mat.clustermap(self.df_norm, tree_kws={"color": rgb})
        for ax in (g.ax_col_dendrogram, g.ax_row_dendrogram):
            # Each dendrogram axes must hold exactly one collection
            [tree] = ax.collections
            drawn = tuple(tree.get_color().squeeze())
            assert drawn[:3] == rgb


if _no_scipy:

    def test_required_scipy_errors():
        """Clustering entry points raise RuntimeError when scipy is missing."""
        data = np.random.normal(0, 1, (10, 10))

        for entry_point in (mat.clustermap, mat.ClusterGrid, mat.dendrogram):
            with pytest.raises(RuntimeError):
                entry_point(data)


================================================
FILE: tests/test_miscplot.py
================================================
import matplotlib.pyplot as plt

from seaborn import miscplot as misc
from seaborn.palettes import color_palette
from .test_utils import _network


class TestPalPlot:
    """Test the function that visualizes a color palette."""
    def test_palplot_size(self):
        """Check the figure size in inches for three palette/size inputs."""
        # Four colors at the default size
        misc.palplot(color_palette("husl", 4))
        assert tuple(plt.gcf().get_size_inches()) == (4, 1)

        # Five colors at the default size
        misc.palplot(color_palette("husl", 5))
        assert tuple(plt.gcf().get_size_inches()) == (5, 1)

        # Three colors with an explicit size of 2
        misc.palplot(color_palette("husl", 3), 2)
        assert tuple(plt.gcf().get_size_inches()) == (6, 2)


class TestDogPlot:
    """Test the function that draws a dog picture on the current axes."""

    @_network(url="https://github.com/mwaskom/seaborn-data")
    def test_dogplot(self):
        """After ``dogplot`` the current axes hold exactly one image."""
        misc.dogplot()
        n_images = len(plt.gca().images)
        assert n_images == 1


================================================
FILE: tests/test_objects.py
================================================
import seaborn.objects
from seaborn._core.plot import Plot
from seaborn._core.moves import Move
from seaborn._core.scales import Scale
from seaborn._marks.base import Mark
from seaborn._stats.base import Stat


def test_objects_namespace():
    """Every non-dunder name in ``seaborn.objects`` is an interface class."""
    allowed_bases = (Plot, Mark, Stat, Move, Scale)
    exported = [
        name for name in dir(seaborn.objects) if not name.startswith("__")
    ]
    for name in exported:
        assert issubclass(getattr(seaborn.objects, name), allowed_bases)


================================================
FILE: tests/test_palettes.py
================================================
import colorsys
import numpy as np
import matplotlib as mpl

import pytest
import numpy.testing as npt

from seaborn import palettes, utils, rcmod
from seaborn.external import husl
from seaborn._compat import get_colormap
from seaborn.colors import xkcd_rgb, crayons


class TestColorPalettes:
    """Tests for the palette constructors and helpers in ``seaborn.palettes``."""

    def test_current_palette(self):
        """Setting a palette makes it the active color cycle."""

        pal = palettes.color_palette(["red", "blue", "green"])
        try:
            rcmod.set_palette(pal)
            assert pal == utils.get_color_cycle()
        finally:
            # Reset even when the assertion fails so that the modified
            # color cycle cannot leak into other tests.
            rcmod.set()

    def test_palette_context(self):
        """A palette used as a context manager is active only inside it."""

        default_pal = palettes.color_palette()
        context_pal = palettes.color_palette("muted")

        with palettes.color_palette(context_pal):
            assert utils.get_color_cycle() == context_pal

        assert utils.get_color_cycle() == default_pal

    def test_big_palette_context(self):
        """A 10-color context palette is restored to the 8-color original."""

        original_pal = palettes.color_palette("deep", n_colors=8)
        context_pal = palettes.color_palette("husl", 10)

        try:
            rcmod.set_palette(original_pal)
            with palettes.color_palette(context_pal, 10):
                assert utils.get_color_cycle() == context_pal

            assert utils.get_color_cycle() == original_pal
        finally:
            # Reset default, also when one of the assertions above fails
            rcmod.set()

    def test_palette_size(self):
        """Named palettes come back with their expected default lengths."""

        pal = palettes.color_palette("deep")
        assert len(pal) == palettes.QUAL_PALETTE_SIZES["deep"]

        pal = palettes.color_palette("pastel6")
        assert len(pal) == palettes.QUAL_PALETTE_SIZES["pastel6"]

        pal = palettes.color_palette("Set3")
        assert len(pal) == palettes.QUAL_PALETTE_SIZES["Set3"]

        pal = palettes.color_palette("husl")
        assert len(pal) == 6

        pal = palettes.color_palette("Greens")
        assert len(pal) == 6

    def test_seaborn_palettes(self):
        """Each 6-color seaborn palette is a fixed subset of its 10-color form."""

        pals = "deep", "muted", "pastel", "bright", "dark", "colorblind"
        for name in pals:
            full = palettes.color_palette(name, 10).as_hex()
            short = palettes.color_palette(name + "6", 6).as_hex()
            # Positions 1, 5, 6 and 7 of the full palette are not in the
            # short one.
            b, _, g, r, m, _, _, _, y, c = full
            assert [b, g, r, m, y, c] == list(short)

    def test_hls_palette(self):
        """``hls_palette`` agrees with ``color_palette("hls")``."""

        pal1 = palettes.hls_palette()
        pal2 = palettes.color_palette("hls")
        npt.assert_array_equal(pal1, pal2)

        cmap1 = palettes.hls_palette(as_cmap=True)
        cmap2 = palettes.color_palette("hls", as_cmap=True)
        npt.assert_array_equal(cmap1([.2, .8]), cmap2([.2, .8]))

    def test_husl_palette(self):
        """``husl_palette`` agrees with ``color_palette("husl")``."""

        pal1 = palettes.husl_palette()
        pal2 = palettes.color_palette("husl")
        npt.assert_array_equal(pal1, pal2)

        cmap1 = palettes.husl_palette(as_cmap=True)
        cmap2 = palettes.color_palette("husl", as_cmap=True)
        npt.assert_array_equal(cmap1([.2, .8]), cmap2([.2, .8]))

    def test_mpl_palette(self):
        """``mpl_palette`` agrees with ``color_palette`` and ``get_colormap``."""

        pal1 = palettes.mpl_palette("Reds")
        pal2 = palettes.color_palette("Reds")
        npt.assert_array_equal(pal1, pal2)

        cmap1 = get_colormap("Reds")
        cmap2 = palettes.mpl_palette("Reds", as_cmap=True)
        cmap3 = palettes.color_palette("Reds", as_cmap=True)
        npt.assert_array_equal(cmap1, cmap2)
        npt.assert_array_equal(cmap1, cmap3)

    def test_mpl_dark_palette(self):
        """The ``_d`` palette names resolve identically through both APIs."""

        mpl_pal1 = palettes.mpl_palette("Blues_d")
        mpl_pal2 = palettes.color_palette("Blues_d")
        npt.assert_array_equal(mpl_pal1, mpl_pal2)

        mpl_pal1 = palettes.mpl_palette("Blues_r_d")
        mpl_pal2 = palettes.color_palette("Blues_r_d")
        npt.assert_array_equal(mpl_pal1, mpl_pal2)

    def test_bad_palette_name(self):
        """An unknown palette name raises ValueError."""

        with pytest.raises(ValueError):
            palettes.color_palette("IAmNotAPalette")

    def test_terrible_palette_name(self):
        """Requesting the "jet" palette raises ValueError."""

        with pytest.raises(ValueError):
            palettes.color_palette("jet")

    def test_bad_palette_colors(self):
        """A color list containing an invalid color raises ValueError."""

        pal = ["red", "blue", "iamnotacolor"]
        with pytest.raises(ValueError):
            palettes.color_palette(pal)

    def test_palette_desat(self):
        """``desat`` matches desaturating each color individually."""

        pal1 = palettes.husl_palette(6)
        pal1 = [utils.desaturate(c, .5) for c in pal1]
        pal2 = palettes.color_palette("husl", desat=.5)
        npt.assert_array_equal(pal1, pal2)

    def test_palette_is_list_of_tuples(self):
        """An array of color names becomes a list of 3-tuples of floats."""

        pal_in = np.array(["red", "blue", "green"])
        pal_out = palettes.color_palette(pal_in, 3)

        assert isinstance(pal_out, list)
        assert isinstance(pal_out[0], tuple)
        assert isinstance(pal_out[0][0], float)
        assert len(pal_out[0]) == 3

    def test_palette_cycles(self):
        """Asking for more colors than a palette holds cycles through it."""

        deep = palettes.color_palette("deep6")
        double_deep = palettes.color_palette("deep6", 12)
        assert double_deep == deep + deep

    def test_hls_values(self):
        """The ``h``, ``l`` and ``s`` arguments of ``hls_palette`` take effect."""

        # Starting half way around the hue circle rotates a 6-color
        # palette by three positions.
        pal1 = palettes.hls_palette(6, h=0)
        pal2 = palettes.hls_palette(6, h=.5)
        pal2 = pal2[3:] + pal2[:3]
        npt.assert_array_almost_equal(pal1, pal2)

        pal_dark = palettes.hls_palette(5, l=.2)  # noqa
        pal_bright = palettes.hls_palette(5, l=.8)  # noqa
        npt.assert_array_less(list(map(sum, pal_dark)),
                              list(map(sum, pal_bright)))

        pal_flat = palettes.hls_palette(5, s=.1)
        pal_bold = palettes.hls_palette(5, s=.9)
        npt.assert_array_less(list(map(np.std, pal_flat)),
                              list(map(np.std, pal_bold)))

    def test_husl_values(self):
        """The ``h``, ``l`` and ``s`` arguments of ``husl_palette`` take effect."""

        # Starting half way around the hue circle rotates a 6-color
        # palette by three positions.
        pal1 = palettes.husl_palette(6, h=0)
        pal2 = palettes.husl_palette(6, h=.5)
        pal2 = pal2[3:] + pal2[:3]
        npt.assert_array_almost_equal(pal1, pal2)

        pal_dark = palettes.husl_palette(5, l=.2)  # noqa
        pal_bright = palettes.husl_palette(5, l=.8)  # noqa
        npt.assert_array_less(list(map(sum, pal_dark)),
                              list(map(sum, pal_bright)))

        pal_flat = palettes.husl_palette(5, s=.1)
        pal_bold = palettes.husl_palette(5, s=.9)
        npt.assert_array_less(list(map(np.std, pal_flat)),
                              list(map(np.std, pal_bold)))

    def test_cbrewer_qual(self):
        """Shorter qualitative palettes are prefixes of longer ones."""

        pal_short = palettes.mpl_palette("Set1", 4)
        pal_long = palettes.mpl_palette("Set1", 6)
        assert pal_short == pal_long[:4]

        pal_full = palettes.mpl_palette("Set2", 8)
        pal_long = palettes.mpl_palette("Set2", 10)
        assert pal_full == pal_long[:8]

    def test_mpl_reversal(self):
        """The ``_r`` suffix yields the same colors in reverse order."""

        pal_forward = palettes.mpl_palette("BuPu", 6)
        pal_reverse = palettes.mpl_palette("BuPu_r", 6)
        npt.assert_array_almost_equal(pal_forward, pal_reverse[::-1])

    def test_rgb_from_hls(self):
        """HLS input is converted exactly as ``colorsys.hls_to_rgb`` does."""

        color = .5, .8, .4
        rgb_got = palettes._color_to_rgb(color, "hls")
        rgb_want = colorsys.hls_to_rgb(*color)
        assert rgb_got == rgb_want

    def test_rgb_from_husl(self):
        """HUSL input matches ``husl.husl_to_rgb`` and stays inside [0, 1]."""

        color = 120, 50, 40
        rgb_got = palettes._color_to_rgb(color, "husl")
        rgb_want = tuple(husl.husl_to_rgb(*color))
        assert rgb_got == rgb_want

        # Sweep every integer hue at (100, 100) and check the RGB bounds
        for h in range(0, 360):
            color = h, 100, 100
            rgb = palettes._color_to_rgb(color, "husl")
            assert min(rgb) >= 0
            assert max(rgb) <= 1

    def test_rgb_from_xkcd(self):
        """An xkcd color name resolves to the RGB of its ``xkcd_rgb`` entry."""

        color = "dull red"
        rgb_got = palettes._color_to_rgb(color, "xkcd")
        rgb_want = mpl.colors.to_rgb(xkcd_rgb[color])
        assert rgb_got == rgb_want

    def test_light_palette(self):
        """``light_palette`` reverses cleanly and matches the "light:" codes."""

        n = 4
        pal_forward = palettes.light_palette("red", n)
        pal_reverse = palettes.light_palette("red", n, reverse=True)
        assert np.allclose(pal_forward, pal_reverse[::-1])

        # The forward palette ends on the seed color itself
        red = mpl.colors.colorConverter.to_rgb("red")
        assert pal_forward[-1] == red

        pal_f_from_string = palettes.color_palette("light:red", n)
        assert pal_forward[3] == pal_f_from_string[3]

        pal_r_from_string = palettes.color_palette("light:red_r", n)
        assert pal_reverse[3] == pal_r_from_string[3]

        pal_cmap = palettes.light_palette("blue", as_cmap=True)
        assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

        pal_cmap_from_string = palettes.color_palette("light:blue", as_cmap=True)
        assert pal_cmap(.8) == pal_cmap_from_string(.8)

        pal_cmap = palettes.light_palette("blue", as_cmap=True, reverse=True)
        pal_cmap_from_string = palettes.color_palette("light:blue_r", as_cmap=True)
        assert pal_cmap(.8) == pal_cmap_from_string(.8)

    def test_dark_palette(self):
        """``dark_palette`` reverses cleanly and matches the "dark:" codes."""

        n = 4
        pal_forward = palettes.dark_palette("red", n)
        pal_reverse = palettes.dark_palette("red", n, reverse=True)
        assert np.allclose(pal_forward, pal_reverse[::-1])

        # The forward palette ends on the seed color itself
        red = mpl.colors.colorConverter.to_rgb("red")
        assert pal_forward[-1] == red

        pal_f_from_string = palettes.color_palette("dark:red", n)
        assert pal_forward[3] == pal_f_from_string[3]

        pal_r_from_string = palettes.color_palette("dark:red_r", n)
        assert pal_reverse[3] == pal_r_from_string[3]

        pal_cmap = palettes.dark_palette("blue", as_cmap=True)
        assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

        pal_cmap_from_string = palettes.color_palette("dark:blue", as_cmap=True)
        assert pal_cmap(.8) == pal_cmap_from_string(.8)

        pal_cmap = palettes.dark_palette("blue", as_cmap=True, reverse=True)
        pal_cmap_from_string = palettes.color_palette("dark:blue_r", as_cmap=True)
        assert pal_cmap(.8) == pal_cmap_from_string(.8)

    def test_diverging_palette(self):
        """Diverging palette ends match the corresponding light palettes."""

        h_neg, h_pos = 100, 200
        sat, lum = 70, 50
        args = h_neg, h_pos, sat, lum

        n = 12
        pal = palettes.diverging_palette(*args, n=n)
        neg_pal = palettes.light_palette((h_neg, sat, lum), int(n // 2),
                                         input="husl")
        pos_pal = palettes.light_palette((h_pos, sat, lum), int(n // 2),
                                         input="husl")
        assert len(pal) == n
        assert pal[0] == neg_pal[-1]
        assert pal[-1] == pos_pal[-1]

        # A dark center must be darker than the default center
        pal_dark = palettes.diverging_palette(*args, n=n, center="dark")
        assert np.mean(pal[int(n / 2)]) > np.mean(pal_dark[int(n / 2)])

        pal_cmap = palettes.diverging_palette(*args, as_cmap=True)
        assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

    def test_blend_palette(self):
        """``blend_palette`` builds a colormap and matches the "blend:" code."""

        colors = ["red", "yellow", "white"]
        pal_cmap = palettes.blend_palette(colors, as_cmap=True)
        assert isinstance(pal_cmap, mpl.colors.LinearSegmentedColormap)

        colors = ["red", "blue"]
        pal = palettes.blend_palette(colors)
        pal_str = "blend:" + ",".join(colors)
        pal_from_str = palettes.color_palette(pal_str)
        assert pal == pal_from_str

    def test_cubehelix_against_matplotlib(self):
        """These parameters reproduce matplotlib's cubehelix at 8 points."""

        x = np.linspace(0, 1, 8)
        mpl_pal = mpl.cm.cubehelix(x)[:, :3].tolist()

        sns_pal = palettes.cubehelix_palette(8, start=0.5, rot=-1.5, hue=1,
                                             dark=0, light=1, reverse=True)

        assert sns_pal == mpl_pal

    def test_cubehelix_n_colors(self):
        """``cubehelix_palette`` returns the requested number of colors."""

        for n in [3, 5, 8]:
            pal = palettes.cubehelix_palette(n)
            assert len(pal) == n

    def test_cubehelix_reverse(self):
        """``reverse=True`` yields the same colors in reverse order."""

        pal_forward = palettes.cubehelix_palette()
        pal_reverse = palettes.cubehelix_palette(reverse=True)
        assert pal_forward == pal_reverse[::-1]

    def test_cubehelix_cmap(self):
        """The colormap form is consistent with the list form, also reversed."""

        cmap = palettes.cubehelix_palette(as_cmap=True)
        assert isinstance(cmap, mpl.colors.ListedColormap)
        pal = palettes.cubehelix_palette()
        x = np.linspace(0, 1, 6)
        npt.assert_array_equal(cmap(x)[:, :3], pal)

        # The reversed colormap sampled backwards equals the forward one
        cmap_rev = palettes.cubehelix_palette(as_cmap=True, reverse=True)
        x = np.linspace(0, 1, 6)
        pal_forward = cmap(x).tolist()
        pal_reverse = cmap_rev(x[::-1]).tolist()
        assert pal_forward == pal_reverse

    def test_cubehelix_code(self):
        """The "ch:" string codes match explicit ``cubehelix_palette`` calls."""

        color_palette = palettes.color_palette
        cubehelix_palette = palettes.cubehelix_palette

        pal1 = color_palette("ch:", 8)
        pal2 = color_palette(cubehelix_palette(8))
        assert pal1 == pal2

        pal1 = color_palette("ch:.5, -.25,hue = .5,light=.75", 8)
        pal2 = color_palette(cubehelix_palette(8, .5, -.25, hue=.5, light=.75))
        assert pal1 == pal2

        pal1 = color_palette("ch:h=1,r=.5", 9)
        pal2 = color_palette(cubehelix_palette(9, hue=1, rot=.5))
        assert pal1 == pal2

        pal1 = color_palette("ch:_r", 6)
        pal2 = color_palette(cubehelix_palette(6, reverse=True))
        assert pal1 == pal2

        pal1 = color_palette("ch:_r", as_cmap=True)
        pal2 = cubehelix_palette(6, reverse=True, as_cmap=True)
        assert pal1(.5) == pal2(.5)

    def test_xkcd_palette(self):
        """``xkcd_palette`` colors convert back to their ``xkcd_rgb`` hex."""

        names = list(xkcd_rgb.keys())[10:15]
        colors = palettes.xkcd_palette(names)
        for name, color in zip(names, colors):
            as_hex = mpl.colors.rgb2hex(color)
            assert as_hex == xkcd_rgb[name]

    def test_crayon_palette(self):
        """``crayon_palette`` colors convert back to their ``crayons`` hex."""

        names = list(crayons.keys())[10:15]
        colors = palettes.crayon_palette(names)
        for name, color in zip(names, colors):
            as_hex = mpl.colors.rgb2hex(color)
            # The crayons table is compared case-insensitively
            assert as_hex == crayons[name].lower()

    def test_color_codes(self):
        """``set_color_codes`` remaps the single-letter matplotlib codes."""

        try:
            palettes.set_color_codes("deep")
            colors = palettes.color_palette("deep6") + [".1"]
            for code, color in zip("bgrmyck", colors):
                rgb_want = mpl.colors.colorConverter.to_rgb(color)
                rgb_got = mpl.colors.colorConverter.to_rgb(code)
                assert rgb_want == rgb_got
        finally:
            # Undo the remapping even when an assertion fails so that other
            # tests see the standard color codes.
            palettes.set_color_codes("reset")

        with pytest.raises(ValueError):
            palettes.set_color_codes("Set1")

    def test_as_hex(self):
        """``as_hex`` agrees with matplotlib's ``rgb2hex`` for every color."""

        pal = palettes.color_palette("deep")
        for rgb, hex_code in zip(pal, pal.as_hex()):
            assert mpl.colors.rgb2hex(rgb) == hex_code

    def test_preserved_palette_length(self):
        """Passing a palette through ``color_palette`` keeps all its colors."""

        pal_in = palettes.color_palette("Set1", 10)
        pal_out = palettes.color_palette(pal_in)
        assert pal_in == pal_out

    def test_html_repr(self):
        """The HTML repr of a palette mentions every color's hex code."""

        pal = palettes.color_palette()
        html = pal._repr_html_()
        for color in pal.as_hex():
            assert color in html

    def test_colormap_display_patch(self):
        """The patched colormap HTML repr is an image tag naming the colormap."""

        orig_repr_png = getattr(mpl.colors.Colormap, "_repr_png_", None)
        orig_repr_html = getattr(mpl.colors.Colormap, "_repr_html_", None)

        try:
            palettes._patch_colormap_display()
            cmap = mpl.cm.Reds
            # The patched repr is expected to be an <img> element whose alt
            # text begins with the colormap name.
            assert cmap._repr_html_().startswith('<img alt="Reds')
        finally:
            # Put back whatever reprs matplotlib defined before the patch
            if orig_repr_png is not None:
                mpl.colors.Colormap._repr_png_ = orig_repr_png
            if orig_repr_html is not None:
                mpl.colors.Colormap._repr_html_ = orig_repr_html


================================================
FILE: tests/test_rcmod.py
================================================
import pytest
import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy.testing as npt

from seaborn import rcmod, palettes, utils


def has_verdana():
     yhat_log[0]
        assert yhat_log[20] > yhat_lin[20]
        assert yhat_lin[90] > yhat_log[90]

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_regress_n_boot(self):
        """Each fitting routine should return `n_boot` bootstrapped curves."""
        plotter = lm._RegressionPlotter(
            "x", "y", data=self.df, n_boot=self.n_boot,
        )
        want_shape = (self.n_boot, self.grid.size)

        # Ordered from fastest to slowest: linear algebra, np.polyfit,
        # and finally statsmodels
        fits = [
            (plotter.fit_fast, (self.grid,)),
            (plotter.fit_poly, (self.grid, 1)),
            (plotter.fit_statsmodels, (self.grid, smlm.OLS)),
        ]
        for fit, fit_args in fits:
            _, boots = fit(*fit_args)
            npt.assert_equal(boots.shape, want_shape)

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_regress_without_bootstrap(self):
        """With `ci=None`, no fitting routine should return bootstraps."""
        plotter = lm._RegressionPlotter(
            "x", "y", data=self.df, n_boot=self.n_boot, ci=None,
        )

        # Ordered from fastest to slowest: linear algebra, np.polyfit,
        # and finally statsmodels
        fits = [
            (plotter.fit_fast, (self.grid,)),
            (plotter.fit_poly, (self.grid, 1)),
            (plotter.fit_statsmodels, (self.grid, smlm.OLS)),
        ]
        for fit, fit_args in fits:
            _, boots = fit(*fit_args)
            assert boots is None

    def test_regress_bootstrap_seed(self):
        """Two plotters given the same seed should bootstrap identically."""
        kws = dict(data=self.df, n_boot=self.n_boot, seed=200)
        first = lm._RegressionPlotter("x", "y", **kws)
        second = lm._RegressionPlotter("x", "y", **kws)

        _, boots_first = first.fit_fast(self.grid)
        _, boots_second = second.fit_fast(self.grid)
        npt.assert_array_equal(boots_first, boots_second)

    def test_numeric_bins(self):
        """Asking for N bins yields N bins covering the binned values."""
        n_bins = self.bins_numeric
        plotter = lm._RegressionPlotter(self.df.x, self.df.y)
        binned_x, bin_values = plotter.bin_predictor(n_bins)
        npt.assert_equal(len(bin_values), n_bins)
        npt.assert_array_equal(np.unique(binned_x), bin_values)

    def test_provided_bins(self):
        """Explicit bins should be the only values in the binned predictor."""
        given = self.bins_given
        plotter = lm._RegressionPlotter(self.df.x, self.df.y)
        binned_x, _ = plotter.bin_predictor(given)
        npt.assert_array_equal(np.unique(binned_x), given)

    def test_bin_results(self):
        """Observations in a higher bin should all exceed those below it."""
        plotter = lm._RegressionPlotter(self.df.x, self.df.y)
        binned_x, _ = plotter.bin_predictor(self.bins_given)

        x = self.df.x
        low, mid, high = (x[binned_x == value] for value in (-1, 0, 1))
        assert mid.min() > low.max()
        assert high.min() > mid.max()

    def test_scatter_data(self):
        """`scatter_data` returns the inputs, jittered only when requested."""
        d, y_true = self.df.d, self.df.y

        # Without jitter, the `x` and `d` predictors come back unchanged
        for x_true in (self.df.x, d):
            plotter = lm._RegressionPlotter(x_true, y_true)
            x, y = plotter.scatter_data
            npt.assert_array_equal(x, x_true)
            npt.assert_array_equal(y, y_true)

        # Jitter on x perturbs x, by less than the jitter amount, and not y
        plotter = lm._RegressionPlotter(d, y_true, x_jitter=.1)
        x, y = plotter.scatter_data
        assert (x != d).any()
        npt.assert_array_less(np.abs(d - x), np.repeat(.1, len(x)))
        npt.assert_array_equal(y, y_true)

        # Jitter on y leaves x alone and keeps y within .1 of the input
        plotter = lm._RegressionPlotter(d, y_true, y_jitter=.05)
        x, y = plotter.scatter_data
        npt.assert_array_equal(x, d)
        npt.assert_array_less(np.abs(y_true - y), np.repeat(.1, len(y)))

    def test_estimate_data(self):
        """Point estimates equal group means and lie inside their CIs."""
        plotter = lm._RegressionPlotter(
            self.df.d, self.df.y, x_estimator=np.mean,
        )
        x, y, ci = plotter.estimate_data

        unique_d = np.sort(np.unique(self.df.d))
        group_means = self.df.groupby("d").y.mean()
        npt.assert_array_equal(x, unique_d)
        npt.assert_array_almost_equal(y, group_means)

        # Each estimate must fall strictly between its interval bounds
        npt.assert_array_less(np.array(ci)[:, 0], y)
        npt.assert_array_less(y, np.array(ci)[:, 1])

    def test_estimate_cis(self):
        """A larger `ci` gives wider intervals; `ci=None` gives none."""
        seed = 123

        def estimate_ci(**kws):
            # Build a plotter with a mean estimator and return only its CIs
            plotter = lm._RegressionPlotter(
                self.df.d, self.df.y, x_estimator=np.mean, **kws
            )
            _, _, ci = plotter.estimate_data
            return ci

        ci_big = estimate_ci(ci=95, seed=seed)
        ci_wee = estimate_ci(ci=50, seed=seed)
        npt.assert_array_less(np.diff(ci_wee), np.diff(ci_big))

        ci_nil = estimate_ci(ci=None)
        npt.assert_array_equal(ci_nil, [None] * len(ci_nil))

    def test_estimate_units(self):
        """Passing `units` should widen the estimated intervals."""
        # Seed the RNG locally
        shared = dict(data=self.df, seed=345, x_bins=3)

        with_units = lm._RegressionPlotter("x", "y", units="s", **shared)
        _, _, ci_units = with_units.estimate_data
        width_units = np.diff(ci_units, axis=1)

        without_units = lm._RegressionPlotter("x", "y", **shared)
        _, _, ci_plain = without_units.estimate_data
        width_plain = np.diff(ci_plain, axis=1)

        npt.assert_array_less(width_plain, width_units)

    def test_partial(self):
        """Regressing out a shared component should lower the correlation."""
        # `y` and `z` are each `x` plus independent noise
        x = self.rs.randn(100)
        y = x + self.rs.randn(100)
        z = x + self.rs.randn(100)

        def correlation(plotter):
            # Correlation between the plotter's (possibly adjusted) x and y
            return np.corrcoef(plotter.x, plotter.y)[0, 1]

        r_orig = correlation(lm._RegressionPlotter(y, z))

        r_semipartial = correlation(lm._RegressionPlotter(y, z, y_partial=x))
        assert r_semipartial < r_orig

        r_partial = correlation(
            lm._RegressionPlotter(y, z, x_partial=x, y_partial=x)
        )
        assert r_partial < r_orig

        # Repeat the full partial case with pandas inputs for x and y
        x = pd.Series(x)
        y = pd.Series(y)
        r_partial = correlation(
            lm._RegressionPlotter(y, z, x_partial=x, y_partial=x)
        )
        assert r_partial < r_orig

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_logistic_regression(self):
        """Logistic predictions should fall strictly between 0 and 1."""
        plotter = lm._RegressionPlotter(
            "x", "c", data=self.df, logistic=True, n_boot=self.n_boot,
        )
        _, predicted, _ = plotter.fit_regression(x_range=(-3, 3))
        npt.assert_array_less(predicted, 1)
        npt.assert_array_less(0, predicted)

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_logistic_perfect_separation(self):
        """A perfectly separable outcome should give all-NaN predictions."""
        outcome = self.df.x > self.df.x.mean()
        plotter = lm._RegressionPlotter(
            "x", outcome, data=self.df, logistic=True, n_boot=10,
        )
        # Silence RuntimeWarnings raised while fitting
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", RuntimeWarning)
            _, predicted, _ = plotter.fit_regression(x_range=(-3, 3))
        assert np.isnan(predicted).all()

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_robust_regression(self):
        """OLS and robust fits should predict the same number of points."""
        predictions = []
        for extra_kws in ({}, {"robust": True}):
            plotter = lm._RegressionPlotter(
                "x", "y", data=self.df, n_boot=self.n_boot, **extra_kws
            )
            _, yhat, _ = plotter.fit_regression(x_range=(-3, 3))
            predictions.append(yhat)

        ols_yhat, robust_yhat = predictions
        assert len(ols_yhat) == len(robust_yhat)

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_lowess_regression(self):
        """A lowess fit has one prediction per grid point and no error band."""
        plotter = lm._RegressionPlotter("x", "y", data=self.df, lowess=True)
        fit = plotter.fit_regression(x_range=(-3, 3))
        grid, predicted, bands = fit

        assert len(grid) == len(predicted)
        assert bands is None

    def test_regression_options(self):
        """Combining `lowess` with `order` or `logistic` should raise."""
        conflicting = (dict(order=2), dict(logistic=True))
        for extra_kws in conflicting:
            with pytest.raises(ValueError):
                lm._RegressionPlotter(
                    "x", "y", data=self.df, lowess=True, **extra_kws
                )

    def test_regression_limits(self):
        """The fit grid spans the axes limits, or the data when truncated."""
        # Given an axes, the grid should reach both x limits
        _, ax = plt.subplots()
        ax.scatter(self.df.x, self.df.y)
        plotter = lm._RegressionPlotter("x", "y", data=self.df)
        grid, _, _ = plotter.fit_regression(ax)
        x_low, x_high = ax.get_xlim()
        assert grid.min() == x_low
        assert grid.max() == x_high

        # With truncation, the grid should stop at the data extremes
        plotter = lm._RegressionPlotter(
            "x", "y", data=self.df, truncate=True,
        )
        grid, _, _ = plotter.fit_regression()
        x_values = self.df.x
        assert grid.min() == x_values.min()
        assert grid.max() == x_values.max()


class TestRegressionPlots:
    """Tests for the plotting functions `regplot`, `lmplot`, and `residplot`."""

    # Shared data: 90 rows with random x and y, grouping columns g (three
    # labels) and h (two labels), and an integer column u cycling over 0-5.
    # The statement order matters because every draw uses the same RNG.
    rs = np.random.RandomState(56)
    df = pd.DataFrame(dict(x=rs.randn(90),
                           y=rs.randn(90) + 5,
                           # NOTE(review): randint's upper bound is
                           # exclusive, so this column is all zeros --
                           # confirm that is intended.
                           z=rs.randint(0, 1, 90),
                           g=np.repeat(list("abc"), 30),
                           h=np.tile(list("xy"), 45),
                           u=np.tile(np.arange(6), 15)))
    # Add one random offset per value of u to y
    bw_err = rs.randn(6)[df.u.values]
    df.y += bw_err

    def test_regplot_basic(self):
        """A default regplot adds one line and two collections to the axes."""

        f, ax = plt.subplots()
        # No `ax` is passed, so this relies on the new axes being current
        lm.regplot(x="x", y="y", data=self.df)
        assert len(ax.lines) == 1
        assert len(ax.collections) == 2

        # The first collection should hold the raw observations
        x, y = ax.collections[0].get_offsets().T
        npt.assert_array_equal(x, self.df.x)
        npt.assert_array_equal(y, self.df.y)

    def test_regplot_selective(self):
        """Turning off the scatter, the fit, or the CI removes its artists."""

        # No scatter: one line and one collection remain
        f, ax = plt.subplots()
        ax = lm.regplot(x="x", y="y", data=self.df, scatter=False, ax=ax)
        assert len(ax.lines) == 1
        assert len(ax.collections) == 1
        ax.clear()

        # No fit: no line, one collection
        f, ax = plt.subplots()
        ax = lm.regplot(x="x", y="y", data=self.df, fit_reg=False)
        assert len(ax.lines) == 0
        assert len(ax.collections) == 1
        ax.clear()

        # No CI: one line and one collection
        f, ax = plt.subplots()
        ax = lm.regplot(x="x", y="y", data=self.df, ci=None)
        assert len(ax.lines) == 1
        assert len(ax.collections) == 1
        ax.clear()

    def test_regplot_scatter_kws_alpha(self):
        """Check how scatter alpha is resolved from color and `scatter_kws`."""

        # An RGBA color carries its own alpha, so no artist-level alpha is set
        f, ax = plt.subplots()
        color = np.array([[0.3, 0.8, 0.5, 0.5]])
        ax = lm.regplot(x="x", y="y", data=self.df,
                        scatter_kws={'color': color})
        assert ax.collections[0]._alpha is None
        assert ax.collections[0]._facecolors[0, 3] == 0.5

        # An RGB color with no explicit alpha ends up with alpha 0.8
        f, ax = plt.subplots()
        color = np.array([[0.3, 0.8, 0.5]])
        ax = lm.regplot(x="x", y="y", data=self.df,
                        scatter_kws={'color': color})
        assert ax.collections[0]._alpha == 0.8

        # An explicit alpha in scatter_kws is used as given
        f, ax = plt.subplots()
        color = np.array([[0.3, 0.8, 0.5]])
        ax = lm.regplot(x="x", y="y", data=self.df,
                        scatter_kws={'color': color, 'alpha': 0.4})
        assert ax.collections[0]._alpha == 0.4

        # A named color with no explicit alpha also ends up with alpha 0.8
        f, ax = plt.subplots()
        color = 'r'
        ax = lm.regplot(x="x", y="y", data=self.df,
                        scatter_kws={'color': color})
        assert ax.collections[0]._alpha == 0.8

        # With binned x and no fit, every line should carry the given alpha
        f, ax = plt.subplots()
        alpha = .3
        ax = lm.regplot(x="x", y="y", data=self.df,
                        x_bins=5, fit_reg=False,
                        scatter_kws={"alpha": alpha})
        for line in ax.lines:
            assert line.get_alpha() == alpha

    def test_regplot_binned(self):
        """With five x bins, the axes hold six lines and two collections."""

        ax = lm.regplot(x="x", y="y", data=self.df, x_bins=5)
        assert len(ax.lines) == 6
        assert len(ax.collections) == 2

    def test_lmplot_no_data(self):
        """Calling lmplot without `data` should raise TypeError."""

        with pytest.raises(TypeError):
            # keyword argument `data` is required
            lm.lmplot(x="x", y="y")

    def test_lmplot_basic(self):
        """A default lmplot draws one line and two collections in its axes."""

        g = lm.lmplot(x="x", y="y", data=self.df)
        ax = g.axes[0, 0]
        assert len(ax.lines) == 1
        assert len(ax.collections) == 2

        # The first collection should hold the raw observations
        x, y = ax.collections[0].get_offsets().T
        npt.assert_array_equal(x, self.df.x)
        npt.assert_array_equal(y, self.df.y)

    def test_lmplot_hue(self):
        """A two-level hue doubles the number of lines and collections."""

        g = lm.lmplot(x="x", y="y", data=self.df, hue="h")
        ax = g.axes[0, 0]

        assert len(ax.lines) == 2
        assert len(ax.collections) == 4

    def test_lmplot_markers(self):
        """Markers are broadcast per hue level; too many markers raise."""

        # A single marker is repeated for both hue levels
        g1 = lm.lmplot(x="x", y="y", data=self.df, hue="h", markers="s")
        assert g1.hue_kws == {"marker": ["s", "s"]}

        g2 = lm.lmplot(x="x", y="y", data=self.df, hue="h", markers=["o", "s"])
        assert g2.hue_kws == {"marker": ["o", "s"]}

        # Three markers for two hue levels is an error
        with pytest.raises(ValueError):
            lm.lmplot(x="x", y="y", data=self.df, hue="h",
                      markers=["o", "s", "d"])

    def test_lmplot_marker_linewidths(self):
        """The "+" marker collection should use the rcParams line width."""

        g = lm.lmplot(x="x", y="y", data=self.df, hue="h",
                      fit_reg=False, markers=["o", "+"])
        c = g.axes[0, 0].collections
        assert c[1].get_linewidths()[0] == mpl.rcParams["lines.linewidth"]

    def test_lmplot_facets(self):
        """The axes array shape should follow the row/col/col_wrap settings."""

        g = lm.lmplot(x="x", y="y", data=self.df, row="g", col="h")
        assert g.axes.shape == (3, 2)

        # With col_wrap, the axes array is flat
        g = lm.lmplot(x="x", y="y", data=self.df, col="u", col_wrap=4)
        assert g.axes.shape == (6,)

        g = lm.lmplot(x="x", y="y", data=self.df, hue="h", col="u")
        assert g.axes.shape == (1, 6)

    def test_lmplot_hue_col_nolegend(self):
        """No legend is created when hue and col use the same variable."""

        g = lm.lmplot(x="x", y="y", data=self.df, col="h", hue="h")
        assert g._legend is None

    def test_lmplot_scatter_kws(self):
        """Scatter face colors should be the first two palette colors."""

        g = lm.lmplot(x="x", y="y", hue="h", data=self.df, ci=None)
        red_scatter, blue_scatter = g.axes[0, 0].collections

        red, blue = color_palette(n_colors=2)
        npt.assert_array_equal(red, red_scatter.get_facecolors()[0, :3])
        npt.assert_array_equal(blue, blue_scatter.get_facecolors()[0, :3])

    @pytest.mark.parametrize("sharex", [True, False])
    def test_lmplot_facet_truncate(self, sharex):
        """With `truncate=False`, each line should span its axes x limits."""

        g = lm.lmplot(
            data=self.df, x="x", y="y", hue="g", col="h",
            truncate=False, facet_kws=dict(sharex=sharex),
        )

        for ax in g.axes.flat:
            for line in ax.lines:
                xdata = line.get_xdata()
                # Compare the first and last x values to the axes limits
                assert ax.get_xlim() == tuple(xdata[[0, -1]])

    def test_lmplot_sharey(self):
        """Passing `sharey` directly warns and gives independent y limits."""

        # The "b" group has a far larger y range than the "a" group
        df = pd.DataFrame(dict(
            x=[0, 1, 2, 0, 1, 2],
            y=[1, -1, 0, -100, 200, 0],
            z=["a", "a", "a", "b", "b", "b"],
        ))

        with pytest.warns(UserWarning):
            g = lm.lmplot(data=df, x="x", y="y", col="z", sharey=False)
        ax1, ax2 = g.axes.flat
        # The first facet's y limits should sit strictly inside the second's
        assert ax1.get_ylim()[0] > ax2.get_ylim()[0]
        assert ax1.get_ylim()[1] < ax2.get_ylim()[1]

    def test_lmplot_facet_kws(self):
        """An `xlim` given through `facet_kws` should apply to every facet."""

        xlim = -4, 20
        g = lm.lmplot(
            data=self.df, x="x", y="y", col="h", facet_kws={"xlim": xlim}
        )
        for ax in g.axes.flat:
            assert ax.get_xlim() == xlim

    def test_residplot(self):
        """residplot should plot x against residuals of a linear fit."""

        x, y = self.df.x, self.df.y
        ax = lm.residplot(x=x, y=y)

        # Reference residuals from a first-order polynomial fit
        resid = y - np.polyval(np.polyfit(x, y, 1), x)
        x_plot, y_plot = ax.collections[0].get_offsets().T

        npt.assert_array_equal(x, x_plot)
        npt.assert_array_almost_equal(resid, y_plot)

    @pytest.mark.skipif(_no_statsmodels, reason="no statsmodels")
    def test_residplot_lowess(self):
        """With `lowess=True`, a second line is drawn over the sorted x."""

        ax = lm.residplot(x="x", y="y", data=self.df, lowess=True)
        assert len(ax.lines) == 2

        x, y = ax.lines[1].get_xydata().T
        npt.assert_array_equal(x, np.sort(self.df.x))

    @pytest.mark.parametrize("option", ["robust", "lowess"])
    @pytest.mark.skipif(not _no_statsmodels, reason="statsmodels installed")
    def test_residplot_statsmodels_missing_errors(self, long_df, option):
        """Without statsmodels, `robust` and `lowess` raise RuntimeError."""
        with pytest.raises(RuntimeError, match=rf"`{option}=True` requires"):
            lm.residplot(long_df, x="x", y="y", **{option: True})

    def test_three_point_colors(self):
        """An RGB tuple color is respected when there are only three points."""

        x, y = np.random.randn(2, 3)
        ax = lm.regplot(x=x, y=y, color=(1, 0, 0))
        color = ax.collections[0].get_facecolors()
        npt.assert_almost_equal(color[0, :3],
                                (1, 0, 0))

    def test_regplot_xlim(self):
        """Two untruncated fits on the same axes should share x values."""

        f, ax = plt.subplots()
        x, y1, y2 = np.random.randn(3, 50)
        lm.regplot(x=x, y=y1, truncate=False)
        lm.regplot(x=x, y=y2, truncate=False)
        line1, line2 = ax.lines
        assert np.array_equal(line1.get_xdata(), line2.get_xdata())


================================================
FILE: tests/test_relational.py
================================================
from itertools import product
import warnings

import numpy as np
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import same_color, to_rgba

import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn.palettes import color_palette
from seaborn._base import categorical_order, unique_markers

from seaborn.relational import (
    _RelationalPlotter,
    _LinePlotter,
    _ScatterPlotter,
    relplot,
    lineplot,
    scatterplot
)

from seaborn.utils import _draw_figure, _version_predates
from seaborn._compat import get_colormap, get_legend_handles
from seaborn._testing import assert_plots_equal


@pytest.fixture(params=[
    {"x": "x", "y": "y"},
    {"x": "t", "y": "y"},
    {"x": "a", "y": "y"},
    {"x": "x", "y": "y", "hue": "y"},
    {"x": "x", "y": "y", "hue": "a"},
    {"x": "x", "y": "y", "size": "a"},
    {"x": "x", "y": "y", "style": "a"},
    {"x": "x", "y": "y", "hue": "s"},
    {"x": "x", "y": "y", "size": "s"},
    {"x": "x", "y": "y", "style": "s"},
    {"x": "x", "y": "y", "hue": "a", "style": "a"},
    {"x": "x", "y": "y", "hue": "a", "size": "b", "style": "b"},
])
def long_semantics(request):
    """Yield one keyword mapping of plot variables per parametrized run."""
    semantic_kws = request.param
    return semantic_kws


class Helpers:
    """Fixtures and comparison helpers shared by the relational tests."""

    @pytest.fixture
    def levels(self, long_df):
        """Map the "a" and "b" column names to their categorical order."""
        return {var: categorical_order(long_df[var]) for var in ["a", "b"]}

    def scatter_rgbs(self, collections):
        """Return the RGB part of each collection's face color as a tuple.

        Each face color array is squeezed before slicing its first three
        entries.
        """
        return [
            tuple(col.get_facecolor().squeeze()[:3]) for col in collections
        ]

    def paths_equal(self, *args):
        """Return whether all given sequences of paths match element-wise.

        Parameters
        ----------
        *args : sequences of objects with `vertices` and `codes` attributes
            Any number of sequences; each is compared against the first.

        Returns
        -------
        bool
            True when every sequence has the same length as the first and
            every path has vertices and codes equal to those of the path
            at the same position in the first sequence. True when no
            sequences are given.
        """
        if not args:
            return True

        reference, *others = args
        if any(len(other) != len(reference) for other in others):
            return False

        # Compare each position across all sequences, not just two of them
        for first, *rest in zip(*args):
            for other in rest:
                if not np.array_equal(first.vertices, other.vertices):
                    return False
                if not np.array_equal(first.codes, other.codes):
                    return False
        return True


class SharedAxesLevelTests:
    """Checks shared by test classes that define `func` and `get_last_color`."""

    def test_color(self, long_df):
        """Check default color cycling and the `color` and `c` keywords."""
        xy = dict(data=long_df, x="x", y="y")

        def new_axes():
            return plt.figure().subplots()

        # A single call uses the first color in the cycle
        ax = new_axes()
        self.func(ax=ax, **xy)
        assert self.get_last_color(ax) == to_rgba("C0")

        # A second call on the same axes advances to the next color
        ax = new_axes()
        self.func(ax=ax, **xy)
        self.func(ax=ax, **xy)
        assert self.get_last_color(ax) == to_rgba("C1")

        # An explicit `color` is used as given
        ax = new_axes()
        self.func(color="C2", ax=ax, **xy)
        assert self.get_last_color(ax) == to_rgba("C2")

        # The short `c` spelling gives the same result
        ax = new_axes()
        self.func(c="C2", ax=ax, **xy)
        assert self.get_last_color(ax) == to_rgba("C2")


class TestRelationalPlotter(Helpers):

    def test_wide_df_variables(self, wide_df):
        """A wide DataFrame is melted into x, y, hue, and style variables."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_df)
        n_rows, n_cols = wide_df.shape

        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]
        assert len(plotter.plot_data) == np.prod(wide_df.shape)

        # x repeats the index once per column
        assert_array_equal(
            plotter.plot_data["x"], np.tile(wide_df.index, n_cols)
        )

        # y holds the values, column by column
        assert_array_equal(
            plotter.plot_data["y"], wide_df.to_numpy().ravel(order="f")
        )

        # hue and style both hold the column label of each value
        column_labels = np.repeat(wide_df.columns.to_numpy(), n_rows)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], column_labels)

        assert plotter.variables["x"] == wide_df.index.name
        assert plotter.variables["y"] is None
        assert plotter.variables["hue"] == wide_df.columns.name
        assert plotter.variables["style"] == wide_df.columns.name

    def test_wide_df_with_nonnumeric_variables(self, long_df):
        """Only the numeric columns of a DataFrame are used in wide form."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=long_df)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        numeric = long_df.select_dtypes("number")
        n_rows, n_cols = numeric.shape

        assert len(plotter.plot_data) == np.prod(numeric.shape)

        # x repeats the index once per numeric column
        assert_array_equal(
            plotter.plot_data["x"], np.tile(numeric.index, n_cols)
        )

        # y holds the numeric values, column by column
        assert_array_equal(
            plotter.plot_data["y"], numeric.to_numpy().ravel(order="f")
        )

        # hue and style both hold the column label of each value
        column_labels = np.repeat(numeric.columns.to_numpy(), n_rows)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], column_labels)

        assert plotter.variables["x"] == numeric.index.name
        assert plotter.variables["y"] is None
        assert plotter.variables["hue"] == numeric.columns.name
        assert plotter.variables["style"] == numeric.columns.name

    def test_wide_array_variables(self, wide_array):
        """A 2D array is melted using positional row and column indices."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_array)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]
        assert len(plotter.plot_data) == np.prod(wide_array.shape)

        n_rows, n_cols = wide_array.shape

        # x counts the rows, repeated once per column
        assert_array_equal(
            plotter.plot_data["x"], np.tile(np.arange(n_rows), n_cols)
        )

        # y holds the values, column by column
        assert_array_equal(
            plotter.plot_data["y"], wide_array.ravel(order="f")
        )

        # hue and style both hold the column position of each value
        column_ids = np.repeat(np.arange(n_cols), n_rows)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], column_ids)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_flat_array_variables(self, flat_array):
        """A 1D array becomes y values plotted against their positions."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=flat_array)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y"]
        assert len(plotter.plot_data) == np.prod(flat_array.shape)

        positions = np.arange(flat_array.shape[0])
        assert_array_equal(plotter.plot_data["x"], positions)
        assert_array_equal(plotter.plot_data["y"], flat_array)

        # Neither variable has a name
        for var in ("x", "y"):
            assert plotter.variables[var] is None

    def test_flat_list_variables(self, flat_list):
        """A flat list becomes y values plotted against their positions."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=flat_list)
        n_values = len(flat_list)

        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y"]
        assert len(plotter.plot_data) == n_values

        assert_array_equal(plotter.plot_data["x"], np.arange(n_values))
        assert_array_equal(plotter.plot_data["y"], flat_list)

        # Neither variable has a name
        for var in ("x", "y"):
            assert plotter.variables[var] is None

    def test_flat_series_variables(self, flat_series):
        """A Series becomes y values plotted against its index.

        The x variable should take its name from the index and the y
        variable from the Series itself.
        """
        p = _RelationalPlotter()
        p.assign_variables(data=flat_series)
        assert p.input_format == "wide"
        assert list(p.variables) == ["x", "y"]
        assert len(p.plot_data) == len(flat_series)

        x = p.plot_data["x"]
        expected_x = flat_series.index
        assert_array_equal(x, expected_x)

        y = p.plot_data["y"]
        expected_y = flat_series
        assert_array_equal(y, expected_y)

        # Compare names by value, as the wide DataFrame tests do; object
        # identity is not part of what is being checked here.
        assert p.variables["x"] == flat_series.index.name
        assert p.variables["y"] == flat_series.name

    def test_wide_list_of_series_variables(self, wide_list_of_series):
        """A list of Series is aligned on the union of their indices."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_list_of_series)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        n_series = len(wide_list_of_series)
        longest = max(len(s) for s in wide_list_of_series)

        assert len(plotter.plot_data) == n_series * longest

        all_index_values = np.concatenate(
            [s.index for s in wide_list_of_series]
        )
        index_union = np.unique(all_index_values)

        # x repeats the index union once per series
        assert_array_equal(
            plotter.plot_data["x"], np.tile(index_union, n_series)
        )

        # y holds each series reindexed onto the union
        reindexed = [s.reindex(index_union) for s in wide_list_of_series]
        assert_array_equal(plotter.plot_data["y"], np.concatenate(reindexed))

        # hue and style both hold the name of the originating series
        names = [s.name for s in wide_list_of_series]
        name_labels = np.repeat(names, longest)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], name_labels)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_wide_list_of_arrays_variables(self, wide_list_of_arrays):
        """A list of arrays is stacked, each padded to the longest length."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_list_of_arrays)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        n_arrays = len(wide_list_of_arrays)
        longest = max(len(arr) for arr in wide_list_of_arrays)

        assert len(plotter.plot_data) == n_arrays * longest

        # x counts positions within each array
        assert_array_equal(
            plotter.plot_data["x"], np.tile(np.arange(longest), n_arrays)
        )

        # After dropping missing entries, y is the arrays end to end
        assert_array_equal(
            plotter.plot_data["y"].dropna(),
            np.concatenate(wide_list_of_arrays),
        )

        # hue and style both hold the list position of the source array
        array_ids = np.repeat(np.arange(n_arrays), longest)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], array_ids)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_wide_list_of_list_variables(self, wide_list_of_lists):
        """A list of lists is stacked, each padded to the longest length."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_list_of_lists)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        n_lists = len(wide_list_of_lists)
        longest = max(len(values) for values in wide_list_of_lists)

        assert len(plotter.plot_data) == n_lists * longest

        # x counts positions within each list
        assert_array_equal(
            plotter.plot_data["x"], np.tile(np.arange(longest), n_lists)
        )

        # After dropping missing entries, y is the lists end to end
        assert_array_equal(
            plotter.plot_data["y"].dropna(),
            np.concatenate(wide_list_of_lists),
        )

        # hue and style both hold the outer position of the source list
        list_ids = np.repeat(np.arange(n_lists), longest)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], list_ids)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_wide_dict_of_series_variables(self, wide_dict_of_series):
        """A dict of Series is stacked and labeled by its keys."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_dict_of_series)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        keys = list(wide_dict_of_series)
        series = list(wide_dict_of_series.values())
        n_series = len(wide_dict_of_series)
        longest = max(len(s) for s in series)

        assert len(plotter.plot_data) == n_series * longest

        # x counts positions within each series
        assert_array_equal(
            plotter.plot_data["x"], np.tile(np.arange(longest), n_series)
        )

        # After dropping missing entries, y is the series end to end
        assert_array_equal(
            plotter.plot_data["y"].dropna(), np.concatenate(series)
        )

        # hue and style both hold the dict key of the source series
        key_labels = np.repeat(keys, longest)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], key_labels)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_wide_dict_of_arrays_variables(self, wide_dict_of_arrays):
        """A dict of arrays is stacked and labeled by its keys."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_dict_of_arrays)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        keys = list(wide_dict_of_arrays)
        arrays = list(wide_dict_of_arrays.values())
        n_arrays = len(wide_dict_of_arrays)
        longest = max(len(arr) for arr in arrays)

        assert len(plotter.plot_data) == n_arrays * longest

        # x counts positions within each array
        assert_array_equal(
            plotter.plot_data["x"], np.tile(np.arange(longest), n_arrays)
        )

        # After dropping missing entries, y is the arrays end to end
        assert_array_equal(
            plotter.plot_data["y"].dropna(), np.concatenate(arrays)
        )

        # hue and style both hold the dict key of the source array
        key_labels = np.repeat(keys, longest)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], key_labels)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_wide_dict_of_lists_variables(self, wide_dict_of_lists):
        """A dict of lists is stacked and labeled by its keys."""
        plotter = _RelationalPlotter()
        plotter.assign_variables(data=wide_dict_of_lists)
        assert plotter.input_format == "wide"
        assert list(plotter.variables) == ["x", "y", "hue", "style"]

        keys = list(wide_dict_of_lists)
        lists = list(wide_dict_of_lists.values())
        n_lists = len(wide_dict_of_lists)
        longest = max(len(values) for values in lists)

        assert len(plotter.plot_data) == n_lists * longest

        # x counts positions within each list
        assert_array_equal(
            plotter.plot_data["x"], np.tile(np.arange(longest), n_lists)
        )

        # After dropping missing entries, y is the lists end to end
        assert_array_equal(
            plotter.plot_data["y"].dropna(), np.concatenate(lists)
        )

        # hue and style both hold the dict key of the source list
        key_labels = np.repeat(keys, longest)
        for semantic in ("hue", "style"):
            assert_array_equal(plotter.plot_data[semantic], key_labels)

        # No variable has a name
        for var in ("x", "y", "hue", "style"):
            assert plotter.variables[var] is None

    def test_relplot_simple(self, long_df):
        """Check the scatter and line kinds, and rejection of unknown kinds."""
        xy = dict(data=long_df, x="x", y="y")

        # Scatter: the collection holds the raw observations
        grid = relplot(kind="scatter", **xy)
        x, y = grid.ax.collections[0].get_offsets().T
        assert_array_equal(x, long_df["x"])
        assert_array_equal(y, long_df["y"])

        # Line: the line follows the mean of y at each x
        grid = relplot(kind="line", **xy)
        x, y = grid.ax.lines[0].get_xydata().T
        y_means = long_df.groupby("x").y.mean()
        assert_array_equal(x, y_means.index)
        assert y == pytest.approx(y_means.values)

        with pytest.raises(ValueError):
            relplot(kind="not_a_kind", **xy)

    def test_relplot_complex(self, long_df):
        """Each facet should hold its subset of data for every semantic."""

        def check_offsets(ax, frame):
            # The first collection on `ax` should match x/y of `frame`
            x, y = ax.collections[0].get_offsets().T
            assert_array_equal(x, frame["x"])
            assert_array_equal(y, frame["y"])

        semantics = ["hue", "size", "style"]

        # One semantic, no faceting
        for sem in semantics:
            grid = relplot(data=long_df, x="x", y="y", **{sem: "a"})
            check_offsets(grid.ax, long_df)

        # One semantic, faceted over columns
        for sem in semantics:
            grid = relplot(
                data=long_df, x="x", y="y", col="c", **{sem: "a"}
            )
            by_col = long_df.groupby("c")
            for (_, subset), ax in zip(by_col, grid.axes.flat):
                check_offsets(ax, subset)

        # Hue plus a second semantic, faceted over columns
        for sem in ["size", "style"]:
            grid = relplot(
                data=long_df, x="x", y="y", hue="b", col="c", **{sem: "a"}
            )
            by_col = long_df.groupby("c")
            for (_, subset), ax in zip(by_col, grid.axes.flat):
                check_offsets(ax, subset)

        # One semantic, faceted over both rows and columns
        for sem in semantics:
            grid = relplot(
                data=long_df.sort_values(["c", "b"]),
                x="x", y="y", col="b", row="c", **{sem: "a"}
            )
            by_cell = long_df.groupby(["c", "b"])
            for (_, subset), ax in zip(by_cell, grid.axes.flat):
                check_offsets(ax, subset)

    @pytest.mark.parametrize("vector_type", ["series", "numpy", "list"])
    def test_relplot_vectors(self, long_df, vector_type):
        """Check `relplot` when variables are passed as vectors, not column names."""
        converters = {
            "numpy": lambda vec: vec.to_numpy(),
            "list": lambda vec: vec.to_list(),
        }
        # Any other vector type (i.e. "series") passes the column through as-is
        convert = converters.get(vector_type, lambda vec: vec)

        columns = dict(x="x", y="y", hue="f", col="c")
        kws = {name: convert(long_df[col]) for name, col in columns.items()}

        grid = relplot(data=long_df, **kws)
        by_col = long_df.groupby("c")
        assert len(grid.axes_dict) == len(by_col)
        for (_, facet_df), ax in zip(by_col, grid.axes.flat):
            pt_x, pt_y = ax.collections[0].get_offsets().T
            assert_array_equal(pt_x, facet_df["x"])
            assert_array_equal(pt_y, facet_df["y"])

    def test_relplot_wide(self, wide_df):
        """Check y positions and the empty y label for wide-form input."""
        grid = relplot(data=wide_df)
        _, pt_y = grid.ax.collections[0].get_offsets().T
        # Expected y values are the columns of the frame laid end to end
        column_major = wide_df.to_numpy().T.ravel()
        assert_array_equal(pt_y, column_major)
        assert not grid.ax.get_ylabel()

    def test_relplot_hues(self, long_df):
        """Check that a list palette colors the points in every column facet."""
        color_names = ["r", "b", "g"]
        grid = relplot(
            x="x", y="y", hue="a", style="b", col="c",
            palette=color_names, data=long_df
        )

        # Pair each hue level (in order of appearance) with its palette entry
        color_by_level = dict(zip(long_df["a"].unique(), color_names))
        for (_, facet_df), ax in zip(long_df.groupby("c"), grid.axes.flat):
            scatter = ax.collections[0]
            wanted = [color_by_level[level] for level in facet_df["a"]]
            assert same_color(scatter.get_facecolors(), wanted)

    def test_relplot_sizes(self, long_df):
        """Check that a list of sizes is applied to points in every column facet."""
        size_values = [5, 12, 7]
        grid = relplot(
            data=long_df,
            x="x", y="y", size="a", hue="b", col="c",
            sizes=size_values,
        )

        # Pair each size level (in order of appearance) with its size value
        size_by_level = dict(zip(long_df["a"].unique(), size_values))
        for (_, facet_df), ax in zip(long_df.groupby("c"), grid.axes.flat):
            scatter = ax.collections[0]
            wanted = [size_by_level[level] for level in facet_df["a"]]
            assert_array_equal(scatter.get_sizes(), wanted)

    def test_relplot_styles(self, long_df):
        """Check that a list of markers sets the point paths in every column facet."""
        markers = ["o", "d", "s"]
        grid = relplot(
            data=long_df,
            x="x", y="y", style="a", hue="b", col="c",
            markers=markers,
        )

        # Map each style level to the transformed path of its marker
        marker_styles = [mpl.markers.MarkerStyle(marker) for marker in markers]
        path_by_level = {
            level: style.get_path().transformed(style.get_transform())
            for level, style in zip(long_df["a"].unique(), marker_styles)
        }

        for (_, facet_df), ax in zip(long_df.groupby("c"), grid.axes.flat):
            scatter = ax.collections[0]
            wanted = [path_by_level[level] for level in facet_df["a"]]
            assert self.paths_equal(scatter.get_paths(), wanted)

    def test_relplot_weighted_estimator(self, long_df):
        """Check that a line `relplot` with `weights` draws weighted averages."""
        grid = relplot(data=long_df, x="a", y="y", weights="x", kind="line")
        drawn = grid.ax.lines[0].get_ydata()
        levels = categorical_order(long_df["a"])
        for idx, level in enumerate(levels):
            subset = long_df[long_df["a"] == level]
            weighted_mean = np.average(subset["y"], weights=subset["x"])
            assert drawn[idx] == pytest.approx(weighted_mean)

    def test_relplot_stringy_numerics(self, long_df):
        """Check that numeric-looking strings as hue/size leave no masked points."""
        long_df["x_str"] = long_df["x"].astype(str)

        for semantic in ["hue", "size"]:
            grid = relplot(data=long_df, x="x", y="y", **{semantic: "x_str"})
            offsets = grid.ax.collections[0].get_offsets()
            # No offset may be masked, and all must match the input coordinates
            assert not np.ma.getmask(offsets).any()
            assert_array_equal(offsets, long_df[["x", "y"]])

    def test_relplot_legend(self, long_df):
        """Check when `relplot` adds a legend, its entry text, and its line colors."""
        # No semantic variables: no legend
        grid = relplot(data=long_df, x="x", y="y")
        assert grid._legend is None

        # Categorical hue: one entry per level, in order of appearance
        grid = relplot(data=long_df, x="x", y="y", hue="a")
        entry_text = [entry.get_text() for entry in grid._legend.texts]
        assert_array_equal(entry_text, long_df["a"].unique())

        # Same variable on hue and size: entries come out sorted
        grid = relplot(data=long_df, x="x", y="y", hue="s", size="s")
        entry_text = [entry.get_text() for entry in grid._legend.texts]
        assert_array_equal(entry_text, np.sort(entry_text))

        # Legend explicitly disabled
        grid = relplot(data=long_df, x="x", y="y", hue="a", legend=False)
        assert grid._legend is None

        # Style variable derived from "a" but carrying the labels of "b"
        b_levels = long_df["b"].unique()
        palette = color_palette("deep", len(b_levels))
        relabel = dict(zip(long_df["a"].unique(), b_levels))
        long_df["a_like_b"] = long_df["a"].map(relabel)
        grid = relplot(
            data=long_df,
            x="x", y="y", hue="b", style="a_like_b",
            palette=palette, kind="line", estimator=None,
        )
        legend_lines = grid._legend.get_lines()[1:]  # Chop off title dummy
        for handle, color in zip(legend_lines, palette):
            assert handle.get_color() == color

    def test_relplot_unshared_axis_labels(self, long_df):
        """Check that only outer facets are labeled when axes are not shared."""
        grid = relplot(
            data=long_df, x="x", y="y", col="a", row="b",
            facet_kws=dict(sharex=False, sharey=False),
        )

        axes = grid.axes
        # (subset of the axes grid, label getter, expected label text)
        expectations = [
            (axes[-1, :], "get_xlabel", "x"),
            (axes[:-1, :], "get_xlabel", ""),
            (axes[:, 0], "get_ylabel", "y"),
            (axes[:, 1:], "get_ylabel", ""),
        ]
        for subset, getter, label in expectations:
            for ax in subset.flat:
                assert getattr(ax, getter)() == label

    def test_relplot_data(self, long_df):
        """Check the columns of `g.data` when dict data is mixed with vectors."""
        renamed_y = long_df["y"].rename("y_var")
        hue_values = long_df["a"].to_numpy()
        grid = relplot(
            data=long_df.to_dict(orient="list"),
            x="x",
            y=renamed_y,
            hue=hue_values,
            col="c",
        )
        # Original columns plus one column for each vector-valued variable
        wanted_columns = set(long_df.columns) | {"_hue_", "y_var"}
        assert set(grid.data.columns) == wanted_columns
        assert_array_equal(grid.data["y_var"], long_df["y"])
        assert_array_equal(grid.data["_hue_"], long_df["a"])

    def test_facet_variable_collision(self, long_df):
        """Check faceting on a data column that is itself named "size"."""
        # https://github.com/mwaskom/seaborn/issues/2488
        facet_values = long_df["c"]
        with_size_col = long_df.assign(size=facet_values)

        grid = relplot(data=with_size_col, x="x", y="y", col="size")

        n_facets = len(facet_values.unique())
        assert grid.axes.shape == (1, n_facets)

    def test_relplot_scatter_unused_variables(self, long_df):
        """Check that `units` and `weights` warn but still produce a plot."""
        for param, column in [("units", "a"), ("weights", "x")]:
            with pytest.warns(UserWarning, match=f"The `{param}` parameter"):
                grid = relplot(long_df, x="x", y="y", **{param: column})
            assert grid.ax is not None

    def test_ax_kwarg_removal(self, long_df):
        """Check that `relplot` warns about `ax` and leaves the given axes empty."""
        _, given_ax = plt.subplots()
        with pytest.warns(UserWarning):
            grid = relplot(data=long_df, x="x", y="y", ax=given_ax)
        # Points land on the grid's own axes, not on the one passed in
        assert len(given_ax.collections) == 0
        assert len(grid.ax.collections) > 0

    def test_legend_has_no_offset(self, long_df):
        """Check that legend labels for large hue values show the full magnitude."""
        shifted_hue = long_df["z"] + 1e8
        grid = relplot(data=long_df, x="x", y="y", hue=shifted_hue)
        for entry in grid.legend.texts:
            label_value = float(entry.get_text())
            assert label_value > 1e7

    def test_lineplot_2d_dashes(self, long_df):
        """Check that a list of dash tuples gives dashed lines for wide data."""
        two_columns = long_df[["x", "y"]]
        ax = lineplot(data=two_columns, dashes=[(5, 5), (10, 10)])
        for drawn in ax.get_lines():
            assert drawn.is_dashed()

    def test_legend_attributes_hue(self, long_df):
        """Check that hue legend handles reflect the scatter keyword arguments."""
        kws = {"s": 50, "linewidth": 1, "marker": "X"}
        grid = relplot(long_df, x="x", y="y", hue="a", **kws)

        palette = color_palette()
        # Scatter `s` is compared against the square of the handle marker size
        wanted_size = np.sqrt(kws["s"])
        # The marker check only runs on matplotlib 3.7.0 or newer
        check_marker = not _version_predates(mpl, "3.7.0")

        for idx, handle in enumerate(get_legend_handles(grid.legend)):
            assert same_color(handle.get_color(), palette[idx])
            assert handle.get_markersize() == wanted_size
            assert handle.get_markeredgewidth() == kws["linewidth"]
            if check_marker:
                assert handle.get_marker() == kws["marker"]

    def test_legend_attributes_style(self, long_df):
        """Check that style legend handles reflect the scatter keyword arguments."""
        kws = {"s": 50, "linewidth": 1, "color": "r"}
        grid = relplot(long_df, x="x", y="y", style="a", **kws)

        # Scatter `s` is compared against the square of the handle marker size
        wanted_size = np.sqrt(kws["s"])
        for handle in get_legend_handles(grid.legend):
            assert handle.get_markersize() == wanted_size
            assert handle.get_markeredgewidth() == kws["linewidth"]
            assert same_color(handle.get_color(), "r")

    def test_legend_attributes_hue_and_style(self, long_df):
        """Check legend handle size and edge width with both hue and style set."""
        kws = {"s": 50, "linewidth": 1}
        grid = relplot(long_df, x="x", y="y", hue="a", style="b", **kws)

        wanted_size = np.sqrt(kws["s"])
        for handle in get_legend_handles(grid.legend):
            # Handles labeled with the variable names themselves are skipped
            if handle.get_label() in ["a", "b"]:
                continue
            assert handle.get_markersize() == wanted_size
            assert handle.get_markeredgewidth() == kws["linewidth"]


class TestLinePlotter(SharedAxesLevelTests, Helpers):

    func = staticmethod(lineplot)

    def get_last_color(self, ax):
        """Return the RGBA color of the last line on `ax`."""
        newest_line = ax.lines[-1]
        return to_rgba(newest_line.get_color())

    def test_legend_no_semantics(self, long_df):
        """Check that no legend handles exist when no semantic is mapped."""
        ax = lineplot(long_df, x="x", y="y")
        legend_handles = ax.get_legend_handles_labels()[0]
        assert legend_handles == []

    def test_legend_hue_categorical(self, long_df, levels):
        """Check legend labels and colors for a categorical hue variable."""
        ax = lineplot(long_df, x="x", y="y", hue="a")
        handles, labels = ax.get_legend_handles_labels()

        assert labels == levels["a"]

        handle_colors = [handle.get_color() for handle in handles]
        assert handle_colors == color_palette(n_colors=len(labels))

    def test_legend_hue_and_style_same(self, long_df, levels):
        """Check the legend when hue and style map the same variable."""
        ax = lineplot(long_df, x="x", y="y", hue="a", style="a", markers=True)
        handles, labels = ax.get_legend_handles_labels()
        n_labels = len(labels)

        assert labels == levels["a"]

        handle_colors = [handle.get_color() for handle in handles]
        assert handle_colors == color_palette(n_colors=n_labels)

        handle_markers = [handle.get_marker() for handle in handles]
        assert handle_markers == unique_markers(n_labels)

    def test_legend_hue_and_style_diff(self, long_df, levels):
        """Check the legend when hue and style map different variables."""
        ax = lineplot(long_df, x="x", y="y", hue="a", style="b", markers=True)
        handles, labels = ax.get_legend_handles_labels()

        hue_levels, style_levels = levels["a"], levels["b"]

        # Each section is a title entry (the variable name) followed by its levels
        wanted_labels = ["a", *hue_levels, "b", *style_levels]
        wanted_colors = (
            ["w"] + list(color_palette(n_colors=len(hue_levels)))
            + ["w"] + [".2"] * len(style_levels)
        )
        wanted_markers = (
            [""] + ["None"] * len(hue_levels)
            + [""] + unique_markers(len(style_levels))
        )

        assert labels == wanted_labels
        assert [handle.get_color() for handle in handles] == wanted_colors
        assert [handle.get_marker() for handle in handles] == wanted_markers

    def test_legend_hue_and_size_same(self, long_df, levels):
        """Check the legend when hue and size map the same variable."""
        ax = lineplot(long_df, x="x", y="y", hue="a", size="a")
        handles, labels = ax.get_legend_handles_labels()
        n_levels = len(levels["a"])

        assert labels == levels["a"]

        handle_colors = [handle.get_color() for handle in handles]
        assert handle_colors == color_palette(n_colors=n_levels)

        # Expected widths scale the default line width from 2x down to 0.5x
        wanted_widths = [
            scale * mpl.rcParams["lines.linewidth"]
            for scale in np.linspace(2, 0.5, n_levels)
        ]
        handle_widths = [handle.get_linewidth() for handle in handles]
        assert handle_widths == wanted_widths

    @pytest.mark.parametrize("var", ["hue", "size", "style"])
    def test_legend_numerical_full(self, long_df, var):
        """Check that `legend="full"` lists every level of a numeric semantic."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)

        ax = lineplot(x=x, y=y, **{var: z}, legend="full")
        labels = ax.get_legend_handles_labels()[1]
        every_level = [str(level) for level in sorted(set(z))]
        assert labels == every_level

    @pytest.mark.parametrize("var", ["hue", "size", "style"])
    def test_legend_numerical_brief(self, var):
        """Check the `legend="brief"` labels for each numeric semantic."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)

        ax = lineplot(x=x, y=y, **{var: z}, legend="brief")
        labels = ax.get_legend_handles_labels()[1]

        # Style is expected to list every level; hue and size a subset of ticks
        wanted = (
            [str(level) for level in sorted(set(z))]
            if var == "style"
            else ["0", "4", "8", "12", "16"]
        )
        assert labels == wanted

    def test_legend_value_error(self, long_df):
        """Check that an unrecognized `legend` value raises a ValueError."""
        message = r"`legend` must be"
        with pytest.raises(ValueError, match=message):
            lineplot(long_df, x="x", y="y", hue="a", legend="bad_value")

    @pytest.mark.parametrize("var", ["hue", "size"])
    def test_legend_log_norm(self, var):
        """Check that a `LogNorm` gives legend labels a factor of ten apart."""
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)

        log_norm = mpl.colors.LogNorm()
        # Shift by one so every value is positive
        semantic_kws = {var: z + 1, f"{var}_norm": log_norm}
        ax = lineplot(x=x, y=y, **semantic_kws)

        labels = ax.get_legend_handles_labels()[1]
        first, second = float(labels[0]), float(labels[1])
        assert second / first == 10

    @pytest.mark.parametrize("var", ["hue", "size"])
    def test_legend_binary_var(self, var):
        """Check that a two-valued numeric semantic gets one legend entry per value.

        Parameters
        ----------
        var : str
            Name of the semantic ("hue" or "size") the binary vector is mapped to.
        """
        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)

        # Map the binary vector through the parametrized semantic. It used to be
        # passed as `hue` unconditionally, so the "size" case re-tested hue.
        ax = lineplot(x=x, y=y, **{var: z % 2})
        _, labels = ax.get_legend_handles_labels()
        assert labels == ["0", "1"]

    @pytest.mark.parametrize("var", ["hue", "size"])
    def test_legend_binary_numberic_brief(self, long_df, var):
        """Check the brief legend labels when column "f" is mapped to `var`."""
        ax = lineplot(long_df, x="x", y="y", legend="brief", **{var: "f"})
        labels = ax.get_legend_handles_labels()[1]
        assert labels == ['0.20', '0.22', '0.24', '0.26', '0.28']

    def test_plot(self, long_df, repeated_df):
        """Exercise `_LinePlotter.plot` across sorting, semantics, errors and units.

        Each section builds a plotter, draws onto the same axes (cleared between
        sections), and inspects the resulting matplotlib artists.
        """
        f, ax = plt.subplots()

        # Unsorted, unaggregated data is drawn exactly as given
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
            sort=False,
            estimator=None
        )
        p.plot(ax, {})
        line, = ax.lines
        assert_array_equal(line.get_xdata(), long_df.x.to_numpy())
        assert_array_equal(line.get_ydata(), long_df.y.to_numpy())

        # Keyword arguments are forwarded to the line artist
        ax.clear()
        p.plot(ax, {"color": "k", "label": "test"})
        line, = ax.lines
        assert line.get_color() == "k"
        assert line.get_label() == "test"

        # Sorting without aggregation orders the points by x, then y
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
            sort=True, estimator=None
        )

        ax.clear()
        p.plot(ax, {})
        line, = ax.lines
        sorted_data = long_df.sort_values(["x", "y"])
        assert_array_equal(line.get_xdata(), sorted_data.x.to_numpy())
        assert_array_equal(line.get_ydata(), sorted_data.y.to_numpy())

        # Hue mapping: one line per level, colored by the hue map
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
        )

        ax.clear()
        p.plot(ax, {})
        assert len(ax.lines) == len(p._hue_map.levels)
        for line, level in zip(ax.lines, p._hue_map.levels):
            assert line.get_color() == p._hue_map(level)

        # Size mapping: one line per level, with width from the size map
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", size="a"),
        )

        ax.clear()
        p.plot(ax, {})
        assert len(ax.lines) == len(p._size_map.levels)
        for line, level in zip(ax.lines, p._size_map.levels):
            assert line.get_linewidth() == p._size_map(level)

        # Hue and style on the same variable
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", style="a"),
        )
        p.map_style(markers=True)

        ax.clear()
        p.plot(ax, {})
        assert len(ax.lines) == len(p._hue_map.levels)
        assert len(ax.lines) == len(p._style_map.levels)
        for line, level in zip(ax.lines, p._hue_map.levels):
            assert line.get_color() == p._hue_map(level)
            assert line.get_marker() == p._style_map(level, "marker")

        # Hue and style on different variables: one line per level combination
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", style="b"),
        )
        p.map_style(markers=True)

        ax.clear()
        p.plot(ax, {})
        levels = product(p._hue_map.levels, p._style_map.levels)
        expected_line_count = len(p._hue_map.levels) * len(p._style_map.levels)
        assert len(ax.lines) == expected_line_count
        for line, (hue, style) in zip(ax.lines, levels):
            assert line.get_color() == p._hue_map(hue)
            assert line.get_marker() == p._style_map(style, "marker")

        # Aggregation with an error band: mean line plus a single band collection
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y"),
            estimator="mean", err_style="band", errorbar="sd", sort=True
        )

        ax.clear()
        p.plot(ax, {})
        line, = ax.lines
        expected_data = long_df.groupby("x").y.mean()
        assert_array_equal(line.get_xdata(), expected_data.index.to_numpy())
        assert np.allclose(line.get_ydata(), expected_data.to_numpy())
        assert len(ax.collections) == 1

        # Test that nans do not propagate to means or CIs

        p = _LinePlotter(
            variables=dict(
                x=[1, 1, 1, 2, 2, 2, 3, 3, 3],
                y=[1, 2, 3, 3, np.nan, 5, 4, 5, 6],
            ),
            estimator="mean", err_style="band", errorbar="ci", n_boot=100, sort=True,
        )
        ax.clear()
        p.plot(ax, {})
        line, = ax.lines
        assert line.get_xdata().tolist() == [1, 2, 3]
        err_band = ax.collections[0].get_paths()
        assert len(err_band) == 1
        assert len(err_band[0].vertices) == 9

        # Error bands with hue: one PolyCollection per hue level
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
            estimator="mean", err_style="band", errorbar="sd"
        )

        ax.clear()
        p.plot(ax, {})
        assert len(ax.lines) == len(ax.collections) == len(p._hue_map.levels)
        for c in ax.collections:
            assert isinstance(c, mpl.collections.PolyCollection)

        # Error bars with hue: one LineCollection per hue level
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
            estimator="mean", err_style="bars", errorbar="sd"
        )

        ax.clear()
        p.plot(ax, {})
        n_lines = len(ax.lines)
        assert n_lines / 2 == len(ax.collections) == len(p._hue_map.levels)
        assert len(ax.collections) == len(p._hue_map.levels)
        for c in ax.collections:
            assert isinstance(c, mpl.collections.LineCollection)

        # Units without aggregation: one line per unit
        p = _LinePlotter(
            data=repeated_df,
            variables=dict(x="x", y="y", units="u"),
            estimator=None
        )

        ax.clear()
        p.plot(ax, {})
        n_units = len(repeated_df["u"].unique())
        assert len(ax.lines) == n_units

        # Units with hue: one line per (hue level, unit) combination
        p = _LinePlotter(
            data=repeated_df,
            variables=dict(x="x", y="y", hue="a", units="u"),
            estimator=None
        )

        ax.clear()
        p.plot(ax, {})
        n_units *= len(repeated_df["a"].unique())
        assert len(ax.lines) == n_units

        # Combining units with an estimator is an error
        p.estimator = "mean"
        with pytest.raises(ValueError):
            p.plot(ax, {})

        # `err_kws` reaches the error band. The estimator and errorbar are given
        # explicitly so that band artists are drawn and the loop has artists
        # to check.
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
            estimator="mean", errorbar="sd",
            err_style="band", err_kws={"alpha": .5},
        )

        ax.clear()
        p.plot(ax, {})
        assert len(ax.collections) == len(p._hue_map.levels)
        for band in ax.collections:
            assert band.get_alpha() == .5

        # `err_kws` reaches the error bars (estimator/errorbar explicit, as above)
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a"),
            estimator="mean", errorbar="sd",
            err_style="bars", err_kws={"elinewidth": 2},
        )

        ax.clear()
        p.plot(ax, {})
        assert len(ax.collections) == len(p._hue_map.levels)
        for lines in ax.collections:
            # `elinewidth` sets the bar line *width*; the style getter used
            # before can never equal 2.
            assert np.allclose(lines.get_linewidths(), 2)

        # An unknown error style is rejected
        p.err_style = "invalid"
        with pytest.raises(ValueError):
            p.plot(ax, {})

        # Smoke test: numeric-looking strings as hue and size
        x_str = long_df["x"].astype(str)
        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue=x_str),
        )
        ax.clear()
        p.plot(ax, {})

        p = _LinePlotter(
            data=long_df,
            variables=dict(x="x", y="y", size=x_str),
        )
        ax.clear()
        p.plot(ax, {})

    def test_weights(self, long_df):
        """Check that `lineplot` with `weights` draws weighted averages of y."""
        ax = lineplot(long_df, x="a", y="y", weights="x")
        drawn = ax.lines[0].get_ydata()
        levels = categorical_order(long_df["a"])
        for idx, level in enumerate(levels):
            subset = long_df[long_df["a"] == level]
            weighted_mean = np.average(subset["y"], weights=subset["x"])
            assert drawn[idx] == pytest.approx(weighted_mean)

    def test_non_aggregated_data(self):
        """Check that data with one observation per x is drawn unchanged."""
        xs = [1, 2, 3, 4]
        ys = [2, 4, 6, 8]
        ax = lineplot(x=xs, y=ys)
        [only_line] = ax.lines
        assert_array_equal(only_line.get_xdata(), xs)
        assert_array_equal(only_line.get_ydata(), ys)

    def test_orient(self, long_df):
        """Check aggregation and error artists for `orient="y"`, and a bad `orient`."""
        # Swap the variable roles: former "s" becomes "y", former "y" becomes "x"
        swapped = long_df.drop("x", axis=1).rename(columns={"s": "y", "y": "x"})
        common = dict(data=swapped, x="x", y="y", orient="y", errorbar="sd")

        band_ax = plt.figure().subplots()
        lineplot(**common)
        assert len(band_ax.lines) == len(band_ax.collections)
        [mean_line] = band_ax.lines
        x_means = swapped.groupby("y").agg({"x": "mean"}).reset_index()
        assert_array_almost_equal(mean_line.get_xdata(), x_means["x"])
        assert_array_almost_equal(mean_line.get_ydata(), x_means["y"])
        band_vertices = band_ax.collections[0].get_paths()[0].vertices
        y_values = swapped["y"].sort_values().unique()
        assert_array_equal(np.unique(band_vertices[:, 1]), y_values)

        bars_ax = plt.figure().subplots()
        lineplot(err_style="bars", **common)
        segments = bars_ax.collections[0].get_segments()
        # Each bar is expected to sit at a constant y position
        for idx, y_val in enumerate(sorted(swapped["y"].unique())):
            assert (segments[idx][:, 1] == y_val).all()

        with pytest.raises(ValueError, match="`orient` must be either 'x' or 'y'"):
            lineplot(swapped, x="y", y="x", orient="bad")

    def test_log_scale(self):
        """Check line data and error bars when the axes use log scaling."""
        # Log-scaled x axis only: data is drawn unchanged
        _, ax = plt.subplots()
        ax.set_xscale("log")

        xs, ys = [1, 10, 100], [1, 2, 3]
        lineplot(x=xs, y=ys)
        drawn = ax.lines[0]
        assert_array_equal(drawn.get_xdata(), xs)
        assert_array_equal(drawn.get_ydata(), ys)

        # Both axes log-scaled, two observations per x value
        _, ax = plt.subplots()
        ax.set_xscale("log")
        ax.set_yscale("log")

        xs, ys = [1, 1, 2, 2], [1, 10, 1, 100]
        lineplot(x=xs, y=ys, err_style="bars", errorbar=("pi", 100))
        drawn = ax.lines[0]
        assert drawn.get_ydata()[1] == 10

        # The 100% interval bars are expected to span each pair of observations
        bar_segments = ax.collections[0].get_segments()
        assert_array_equal(bar_segments[0][:, 1], ys[:2])
        assert_array_equal(bar_segments[1][:, 1], ys[2:])

    def test_axis_labels(self, long_df):
        """Check axis labels on two axes that share a y axis."""
        _, (left_ax, right_ax) = plt.subplots(1, 2, sharey=True)

        plotter = _LinePlotter(data=long_df, variables=dict(x="x", y="y"))

        for ax in (left_ax, right_ax):
            plotter.plot(ax, {})
            assert ax.get_xlabel() == "x"
            assert ax.get_ylabel() == "y"

        # The second axes keeps its y label text, but it is not visible
        assert not right_ax.yaxis.label.get_visible()

    def test_matplotlib_kwargs(self, long_df):
        """Check that extra keyword arguments are applied to the drawn line."""
        kws = dict(
            linestyle="--",
            linewidth=3,
            color=(1, .5, .2),
            markeredgecolor=(.2, .5, .2),
            markeredgewidth=1,
        )
        ax = lineplot(data=long_df, x="x", y="y", **kws)

        first_line, *_ = ax.lines
        # Each keyword is read back through the line's matching `get_` method
        for prop, wanted in kws.items():
            getter = getattr(first_line, "get_" + prop)
            assert getter() == wanted

    def test_nonmapped_dashes(self):
        """Check that a dash tuple without a style variable gives a dashed line."""
        ax = lineplot(x=[1, 2], y=[1, 2], dashes=(2, 1))
        first_line = ax.lines[0]
        # Not a great test, but lines don't expose the dash style publicly
        assert first_line.get_linestyle() == "--"

    def test_lineplot_axes(self, wide_df):
        """Check which axes `lineplot` returns with and without an `ax` argument."""
        _, first_ax = plt.subplots()
        _, second_ax = plt.subplots()

        # Without `ax`, the most recently created axes is returned
        returned = lineplot(data=wide_df)
        assert returned is second_ax

        # With `ax`, the given axes is returned
        returned = lineplot(data=wide_df, ax=first_ax)
        assert returned is first_ax

    def test_legend_attributes_with_hue(self, long_df):
        """Check that hue legend handles reflect the line keyword arguments."""
        kws = {"marker": "o", "linewidth": 3}
        ax = lineplot(long_df, x="x", y="y", hue="a", **kws)

        palette = color_palette()
        # The marker check only runs on matplotlib 3.7.0 or newer
        check_marker = not _version_predates(mpl, "3.7.0")

        for idx, handle in enumerate(get_legend_handles(ax.get_legend())):
            assert same_color(handle.get_color(), palette[idx])
            assert handle.get_linewidth() == kws["linewidth"]
            if check_marker:
                assert handle.get_marker() == kws["marker"]

    def test_legend_attributes_with_style(self, long_df):
        """Check that style legend handles reflect the line keyword arguments."""
        kws = {"color": "r", "marker": "o", "linewidth": 3}
        ax = lineplot(long_df, x="x", y="y", style="a", **kws)

        # The marker check only runs on matplotlib 3.7.0 or newer
        check_marker = not _version_predates(mpl, "3.7.0")

        for handle in get_legend_handles(ax.get_legend()):
            assert same_color(handle.get_color(), kws["color"])
            if check_marker:
                assert handle.get_marker() == kws["marker"]
            assert handle.get_linewidth() == kws["linewidth"]

    def test_legend_attributes_with_hue_and_style(self, long_df):
        """Check legend handle marker and width with both hue and style set."""
        kws = {"marker": "o", "linewidth": 3}
        ax = lineplot(long_df, x="x", y="y", hue="a", style="b", **kws)

        for handle in get_legend_handles(ax.get_legend()):
            # Handles labeled with the variable names themselves are skipped
            if handle.get_label() in ["a", "b"]:
                continue
            # The marker check only runs on matplotlib 3.7.0 or newer
            if not _version_predates(mpl, "3.7.0"):
                assert handle.get_marker() == kws["marker"]
            assert handle.get_linewidth() == kws["linewidth"]

    def test_lineplot_vs_relplot(self, long_df, long_semantics):
        """Check that `lineplot` and line-kind `relplot` draw matching lines."""
        shared = dict(data=long_df, legend=False, **long_semantics)
        ax = lineplot(**shared)
        grid = relplot(kind="line", **shared)

        for from_lineplot, from_relplot in zip(ax.lines, grid.ax.lines):
            assert_array_equal(
                from_lineplot.get_xydata(), from_relplot.get_xydata()
            )
            assert same_color(from_lineplot.get_color(), from_relplot.get_color())
            assert from_lineplot.get_linewidth() == from_relplot.get_linewidth()
            assert from_lineplot.get_linestyle() == from_relplot.get_linestyle()

    def test_lineplot_smoke(
        self,
        wide_df, wide_array,
        wide_list_of_series, wide_list_of_arrays, wide_list_of_lists,
        flat_array, flat_series, flat_list,
        long_df, null_df, object_df
    ):
        """Smoke-test `lineplot` on many input forms; only checks nothing raises."""
        _, ax = plt.subplots()

        # One entry per call, in the order the calls are made
        call_kwargs = [
            # Empty input
            dict(x=[], y=[]),
            # Wide-form and flat inputs
            dict(data=wide_df),
            dict(data=wide_array),
            dict(data=wide_list_of_series),
            dict(data=wide_list_of_arrays),
            dict(data=wide_list_of_lists),
            dict(data=flat_series),
            dict(data=flat_array),
            dict(data=flat_list),
            # Long-form inputs with names and/or vectors
            dict(x="x", y="y", data=long_df),
            dict(x=long_df.x, y=long_df.y),
            dict(x=long_df.x, y="y", data=long_df),
            dict(x="x", y=long_df.y.to_numpy(), data=long_df),
            dict(x="x", y="t", data=long_df),
            # Hue and style semantics
            dict(x="x", y="y", hue="a", data=long_df),
            dict(x="x", y="y", hue="a", style="a", data=long_df),
            dict(x="x", y="y", hue="a", style="b", data=long_df),
            dict(x="x", y="y", hue="a", style="a", data=null_df),
            dict(x="x", y="y", hue="a", style="b", data=null_df),
            # Hue and size semantics
            dict(x="x", y="y", hue="a", size="a", data=long_df),
            dict(x="x", y="y", hue="a", size="s", data=long_df),
            dict(x="x", y="y", hue="a", size="a", data=null_df),
            dict(x="x", y="y", hue="a", size="s", data=null_df),
            # `object_df` fixture
            dict(x="x", y="y", hue="f", data=object_df),
            dict(x="x", y="y", hue="c", size="f", data=object_df),
            dict(x="x", y="y", hue="f", size="s", data=object_df),
            # Zero-row frame with a semantic variable
            dict(x="x", y="y", hue="a", data=long_df.iloc[:0]),
        ]

        for kwargs in call_kwargs:
            lineplot(**kwargs)
            ax.clear()

    def test_ci_deprecation(self, long_df):
        """Check that deprecated `ci` warns and matches the `errorbar` equivalent."""
        warning_text = "\n\nThe `ci` parameter is deprecated"
        # (errorbar value, equivalent deprecated ci value, extra keyword arguments)
        cases = [
            (("ci", 95), 95, dict(seed=0)),
            ("sd", "sd", {}),
        ]
        for errorbar, ci, extra in cases:
            new_ax, old_ax = axs = plt.figure().subplots(2)
            lineplot(
                data=long_df, x="x", y="y", errorbar=errorbar, ax=new_ax, **extra
            )
            with pytest.warns(FutureWarning, match=warning_text):
                lineplot(data=long_df, x="x", y="y", ci=ci, ax=old_ax, **extra)
            assert_plots_equal(*axs)


class TestScatterPlotter(SharedAxesLevelTests, Helpers):
    """Tests for `scatterplot` and the underlying `_ScatterPlotter`."""

    func = staticmethod(scatterplot)

    def get_last_color(self, ax):
        """Return the single RGBA face color of the last collection on `ax`."""

        colors = ax.collections[-1].get_facecolors()
        unique_colors = np.unique(colors, axis=0)
        assert len(unique_colors) == 1
        return to_rgba(unique_colors.squeeze())

    def test_color(self, long_df):
        """The `facecolor`, `facecolors`, and `fc` keywords set the point color."""

        super().test_color(long_df)

        ax = plt.figure().subplots()
        self.func(data=long_df, x="x", y="y", facecolor="C5", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C5")

        ax = plt.figure().subplots()
        self.func(data=long_df, x="x", y="y", facecolors="C6", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C6")

        ax = plt.figure().subplots()
        self.func(data=long_df, x="x", y="y", fc="C4", ax=ax)
        assert self.get_last_color(ax) == to_rgba("C4")

    def test_legend_no_semantics(self, long_df):
        """No legend handles exist when no semantic variable is assigned."""

        ax = scatterplot(long_df, x="x", y="y")
        handles, _ = ax.get_legend_handles_labels()
        assert not handles

    def test_legend_hue(self, long_df):
        """A hue legend uses the default palette and the categorical level order."""

        ax = scatterplot(long_df, x="x", y="y", hue="a")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        expected_colors = color_palette(n_colors=len(handles))
        assert same_color(colors, expected_colors)
        assert labels == categorical_order(long_df["a"])

    def test_legend_hue_style_same(self, long_df):
        """Hue and style on one variable share a single set of legend entries."""

        ax = scatterplot(long_df, x="x", y="y", hue="a", style="a")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        expected_colors = color_palette(n_colors=len(labels))
        markers = [h.get_marker() for h in handles]
        expected_markers = unique_markers(len(handles))
        assert same_color(colors, expected_colors)
        assert markers == expected_markers
        assert labels == categorical_order(long_df["a"])

    def test_legend_hue_style_different(self, long_df):
        """Hue and style on different variables give two titled legend sections."""

        ax = scatterplot(long_df, x="x", y="y", hue="a", style="b")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        # Each section starts with a title entry: white color, empty marker
        expected_colors = [
            "w", *color_palette(n_colors=long_df["a"].nunique()),
            "w", *[".2" for _ in long_df["b"].unique()],
        ]
        markers = [h.get_marker() for h in handles]
        expected_markers = [
            "", *["o" for _ in long_df["a"].unique()],
            "", *unique_markers(long_df["b"].nunique()),
        ]
        assert same_color(colors, expected_colors)
        assert markers == expected_markers
        assert labels == [
            "a", *categorical_order(long_df["a"]),
            "b", *categorical_order(long_df["b"]),
        ]

    def test_legend_data_hue_size_same(self, long_df):
        """Hue and size on one variable share entries with graded marker sizes."""

        ax = scatterplot(long_df, x="x", y="y", hue="a", size="a")
        handles, labels = ax.get_legend_handles_labels()
        colors = [h.get_color() for h in handles]
        expected_colors = color_palette(n_colors=len(labels))
        sizes = [h.get_markersize() for h in handles]
        ms = mpl.rcParams["lines.markersize"] ** 2
        expected_sizes = np.sqrt(
            [ms * scl for scl in np.linspace(2, 0.5, len(handles))]
        ).tolist()
        assert same_color(colors, expected_colors)
        assert sizes == expected_sizes
        assert labels == categorical_order(long_df["a"])
        assert ax.get_legend().get_title().get_text() == "a"

    def test_legend_size_numeric_list(self, long_df):
        """Legend marker sizes are the square roots of a supplied `sizes` list."""

        size_list = [10, 100, 200]
        ax = scatterplot(long_df, x="x", y="y", size="s", sizes=size_list)
        handles, labels = ax.get_legend_handles_labels()
        sizes = [h.get_markersize() for h in handles]
        expected_sizes = list(np.sqrt(size_list))
        assert sizes == expected_sizes
        assert labels == list(map(str, categorical_order(long_df["s"])))
        assert ax.get_legend().get_title().get_text() == "s"

    def test_legend_size_numeric_dict(self, long_df):
        """Legend marker sizes are the square roots of a supplied `sizes` dict."""

        size_dict = {2: 10, 4: 100, 8: 200}
        ax = scatterplot(long_df, x="x", y="y", size="s", sizes=size_dict)
        handles, labels = ax.get_legend_handles_labels()
        sizes = [h.get_markersize() for h in handles]
        order = categorical_order(long_df["s"])
        expected_sizes = [np.sqrt(size_dict[k]) for k in order]
        assert sizes == expected_sizes
        assert labels == list(map(str, order))
        assert ax.get_legend().get_title().get_text() == "s"

    def test_legend_numeric_hue_full(self):
        """`legend="full"` lists every unique numeric hue value, with no title."""

        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, hue=z, legend="full")
        _, labels = ax.get_legend_handles_labels()
        assert labels == [str(z_i) for z_i in sorted(set(z))]
        assert ax.get_legend().get_title().get_text() == ""

    def test_legend_numeric_hue_brief(self):
        """`legend="brief"` shows fewer entries than unique numeric hue values."""

        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, hue=z, legend="brief")
        _, labels = ax.get_legend_handles_labels()
        assert len(labels) < len(set(z))

    def test_legend_numeric_size_full(self):
        """`legend="full"` lists every unique numeric size value."""

        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, size=z, legend="full")
        _, labels = ax.get_legend_handles_labels()
        assert labels == [str(z_i) for z_i in sorted(set(z))]

    def test_legend_numeric_size_brief(self):
        """`legend="brief"` shows fewer entries than unique numeric size values."""

        x, y = np.random.randn(2, 40)
        z = np.tile(np.arange(20), 2)
        ax = scatterplot(x=x, y=y, size=z, legend="brief")
        _, labels = ax.get_legend_handles_labels()
        assert len(labels) < len(set(z))

    def test_legend_attributes_hue(self, long_df):
        """Hue legend handles reflect the `s`, `linewidth`, and `marker` keywords."""

        kws = {"s": 50, "linewidth": 1, "marker": "X"}
        ax = scatterplot(long_df, x="x", y="y", hue="a", **kws)
        palette = color_palette()
        for i, pt in enumerate(get_legend_handles(ax.get_legend())):
            assert same_color(pt.get_color(), palette[i])
            assert pt.get_markersize() == np.sqrt(kws["s"])
            assert pt.get_markeredgewidth() == kws["linewidth"]
            if not _version_predates(mpl, "3.7.0"):
                # This attribute is empty on older matplotlibs
                # but the legend looks correct so I assume it is a bug
                assert pt.get_marker() == kws["marker"]

    def test_legend_attributes_style(self, long_df):
        """Style legend handles reflect the `s`, `linewidth`, and `color` keywords."""

        kws = {"s": 50, "linewidth": 1, "color": "r"}
        ax = scatterplot(long_df, x="x", y="y", style="a", **kws)
        for pt in get_legend_handles(ax.get_legend()):
            assert pt.get_markersize() == np.sqrt(kws["s"])
            assert pt.get_markeredgewidth() == kws["linewidth"]
            assert same_color(pt.get_color(), "r")

    def test_legend_attributes_hue_and_style(self, long_df):
        """Non-title legend handles reflect `s` and `linewidth` with hue and style."""

        kws = {"s": 50, "linewidth": 1}
        ax = scatterplot(long_df, x="x", y="y", hue="a", style="b", **kws)
        for pt in get_legend_handles(ax.get_legend()):
            # Skip the section-title entries, which are labeled by variable name
            if pt.get_label() not in ["a", "b"]:
                assert pt.get_markersize() == np.sqrt(kws["s"])
                assert pt.get_markeredgewidth() == kws["linewidth"]

    def test_legend_value_error(self, long_df):
        """An unrecognized `legend` value raises a `ValueError`."""

        with pytest.raises(ValueError, match=r"`legend` must be"):
            scatterplot(long_df, x="x", y="y", hue="a", legend="bad_value")

    def test_plot(self, long_df, repeated_df):
        """`_ScatterPlotter.plot` draws offsets, colors, sizes, and marker paths."""

        f, ax = plt.subplots()

        p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y"))

        p.plot(ax, {})
        points = ax.collections[0]
        assert_array_equal(points.get_offsets(), long_df[["x", "y"]].to_numpy())

        ax.clear()
        p.plot(ax, {"color": "k", "label": "test"})
        points = ax.collections[0]
        assert same_color(points.get_facecolor(), "k")
        assert points.get_label() == "test"

        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", hue="a")
        )

        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_colors = p._hue_map(p.plot_data["hue"])
        assert same_color(points.get_facecolors(), expected_colors)

        p = _ScatterPlotter(
            data=long_df,
            variables=dict(x="x", y="y", style="c"),
        )
        p.map_style(markers=["+", "x"])

        ax.clear()
        color = (1, .3, .8)
        p.plot(ax, {"color": color})
        points = ax.collections[0]
        assert same_color(points.get_edgecolors(), [color])

        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", size="a"),
        )

        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_sizes = p._size_map(p.plot_data["size"])
        assert_array_equal(points.get_sizes(), expected_sizes)

        p = _ScatterPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", style="a"),
        )
        p.map_style(markers=True)

        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_colors = p._hue_map(p.plot_data["hue"])
        expected_paths = p._style_map(p.plot_data["style"], "path")
        assert same_color(points.get_facecolors(), expected_colors)
        assert self.paths_equal(points.get_paths(), expected_paths)

        p = _ScatterPlotter(
            data=long_df,
            variables=dict(x="x", y="y", hue="a", style="b"),
        )
        p.map_style(markers=True)

        ax.clear()
        p.plot(ax, {})
        points = ax.collections[0]
        expected_colors = p._hue_map(p.plot_data["hue"])
        expected_paths = p._style_map(p.plot_data["style"], "path")
        assert same_color(points.get_facecolors(), expected_colors)
        assert self.paths_equal(points.get_paths(), expected_paths)

        # Smoke checks: string-typed hue and size vectors plot without error
        x_str = long_df["x"].astype(str)
        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", hue=x_str),
        )
        ax.clear()
        p.plot(ax, {})

        p = _ScatterPlotter(
            data=long_df, variables=dict(x="x", y="y", size=x_str),
        )
        ax.clear()
        p.plot(ax, {})

    def test_axis_labels(self, long_df):
        """Axis labels are set on both axes; the shared y label is hidden on the second."""

        f, (ax1, ax2) = plt.subplots(1, 2, sharey=True)

        p = _ScatterPlotter(data=long_df, variables=dict(x="x", y="y"))

        p.plot(ax1, {})
        assert ax1.get_xlabel() == "x"
        assert ax1.get_ylabel() == "y"

        p.plot(ax2, {})
        assert ax2.get_xlabel() == "x"
        assert ax2.get_ylabel() == "y"
        assert not ax2.yaxis.label.get_visible()

    def test_scatterplot_axes(self, wide_df):
        """`scatterplot` uses the current axes by default and honors `ax`."""

        f1, ax1 = plt.subplots()
        f2, ax2 = plt.subplots()

        ax = scatterplot(data=wide_df)
        assert ax is ax2

        ax = scatterplot(data=wide_df, ax=ax1)
        assert ax is ax1

    def test_literal_attribute_vectors(self):
        """Per-point `c` and `s` vectors are passed through to the collection."""

        f, ax = plt.subplots()

        x = y = [1, 2, 3]
        s = [5, 10, 15]
        c = [(1, 1, 0, 1), (1, 0, 1, .5), (.5, 1, 0, 1)]

        scatterplot(x=x, y=y, c=c, s=s, ax=ax)

        points, = ax.collections

        assert_array_equal(points.get_sizes().squeeze(), s)
        assert_array_equal(points.get_facecolors(), c)

    def test_supplied_color_array(self, long_df):
        """Explicit color arrays, or `c` with `cmap`, determine the face colors."""

        cmap = get_colormap("Blues")
        norm = mpl.colors.Normalize()
        colors = cmap(norm(long_df["y"].to_numpy()))

        keys = ["c", "fc", "facecolor", "facecolors"]

        for key in keys:

            ax = plt.figure().subplots()
            scatterplot(data=long_df, x="x", y="y", **{key: colors})
            _draw_figure(ax.figure)
            assert_array_equal(ax.collections[0].get_facecolors(), colors)

        ax = plt.figure().subplots()
        scatterplot(data=long_df, x="x", y="y", c=long_df["y"], cmap=cmap)
        _draw_figure(ax.figure)
        assert_array_equal(ax.collections[0].get_facecolors(), colors)

    def test_hue_order(self, long_df):
        """Levels omitted from `hue_order` get all-zero colors and no legend text."""

        order = categorical_order(long_df["a"])
        unused = order.pop()

        ax = scatterplot(data=long_df, x="x", y="y", hue="a", hue_order=order)
        points = ax.collections[0]
        assert (points.get_facecolors()[long_df["a"] == unused] == 0).all()
        assert [t.get_text() for t in ax.legend_.texts] == order

    def test_linewidths(self, long_df):
        """Default edge width grows with marker size; `linewidth` overrides it."""

        f, ax = plt.subplots()

        scatterplot(data=long_df, x="x", y="y", s=10)
        scatterplot(data=long_df, x="x", y="y", s=20)
        points1, points2 = ax.collections
        assert (
            points1.get_linewidths().item() < points2.get_linewidths().item()
        )

        ax.clear()
        scatterplot(data=long_df, x="x", y="y", s=long_df["x"])
        scatterplot(data=long_df, x="x", y="y", s=long_df["x"] * 2)
        points1, points2 = ax.collections
        assert (
            points1.get_linewidths().item() < points2.get_linewidths().item()
        )

        ax.clear()
        lw = 2
        scatterplot(data=long_df, x="x", y="y", linewidth=lw)
        assert ax.collections[0].get_linewidths().item() == lw

    def test_size_norm_extrapolation(self):
        """A fixed `size_norm` gives matching sizes for full and truncated data."""

        # https://github.com/mwaskom/seaborn/issues/2539
        x = np.arange(0, 20, 2)
        f, axs = plt.subplots(1, 2, sharex=True, sharey=True)

        slc = 5
        kws = dict(sizes=(50, 200), size_norm=(0, x.max()), legend="brief")

        scatterplot(x=x, y=x, size=x, ax=axs[0], **kws)
        scatterplot(x=x[:slc], y=x[:slc], size=x[:slc], ax=axs[1], **kws)

        assert np.allclose(
            axs[0].collections[0].get_sizes()[:slc],
            axs[1].collections[0].get_sizes()
        )

        legends = [ax.legend_ for ax in axs]
        legend_data = [
            {
                label.get_text(): handle.get_markersize()
                for label, handle in zip(legend.get_texts(), get_legend_handles(legend))
            } for legend in legends
        ]

        for key in set(legend_data[0]) & set(legend_data[1]):
            if key == "y":
                # At some point (circa 3.0) matplotlib auto-added pandas series
                # with a valid name into the legend, which messes up this test.
                # I can't track down when that was added (or removed), so let's
                # just anticipate and ignore it here.
                continue
            assert legend_data[0][key] == legend_data[1][key]

    def test_datetime_scale(self, long_df):
        """The x limits for datetime data stay above a fixed lower bound."""

        ax = scatterplot(data=long_df, x="t", y="y")
        # Check that we avoid weird matplotlib default auto scaling
        # https://github.com/matplotlib/matplotlib/issues/17586
        # The comparison must be asserted; a bare expression is discarded.
        xmin = ax.get_xlim()[0]
        assert xmin > ax.xaxis.convert_units(np.datetime64("2002-01-01"))

    def test_unfilled_marker_edgecolor_warning(self, long_df):  # GH2636
        """An unfilled marker plots without emitting any warning."""

        with warnings.catch_warnings():
            # Escalate warnings to errors so any warning fails the test
            warnings.simplefilter("error")
            scatterplot(data=long_df, x="x", y="y", marker="+")

    def test_short_form_kwargs(self, long_df):
        """The short `ec` keyword sets the edge color."""

        ax = scatterplot(data=long_df, x="x", y="y", ec="g")
        pts = ax.collections[0]
        assert same_color(pts.get_edgecolors().squeeze(), "g")

    def test_scatterplot_vs_relplot(self, long_df, long_semantics):
        """`scatterplot` and `relplot(kind="scatter")` draw equivalent collections."""

        ax = scatterplot(data=long_df, **long_semantics)
        g = relplot(data=long_df, kind="scatter", **long_semantics)

        for s_pts, r_pts in zip(ax.collections, g.ax.collections):

            assert_array_equal(s_pts.get_offsets(), r_pts.get_offsets())
            assert_array_equal(s_pts.get_sizes(), r_pts.get_sizes())
            assert_array_equal(s_pts.get_facecolors(), r_pts.get_facecolors())
            assert self.paths_equal(s_pts.get_paths(), r_pts.get_paths())

    def test_scatterplot_smoke(
        self,
        wide_df, wide_array,
        flat_series, flat_array, flat_list,
        wide_list_of_series, wide_list_of_arrays, wide_list_of_lists,
        long_df, null_df, object_df
    ):
        """Smoke test: `scatterplot` runs on many input forms without error."""

        f, ax = plt.subplots()

        scatterplot(x=[], y=[])
        ax.clear()

        scatterplot(data=wide_df)
        ax.clear()

        scatterplot(data=wide_array)
        ax.clear()

        scatterplot(data=wide_list_of_series)
        ax.clear()

        scatterplot(data=wide_list_of_arrays)
        ax.clear()

        scatterplot(data=wide_list_of_lists)
        ax.clear()

        scatterplot(data=flat_series)
        ax.clear()

        scatterplot(data=flat_array)
        ax.clear()

        scatterplot(data=flat_list)
        ax.clear()

        scatterplot(x="x", y="y", data=long_df)
        ax.clear()

        scatterplot(x=long_df.x, y=long_df.y)
        ax.clear()

        scatterplot(x=long_df.x, y="y", data=long_df)
        ax.clear()

        scatterplot(x="x", y=long_df.y.to_numpy(), data=long_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", data=long_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", style="a", data=long_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", style="b", data=long_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", style="a", data=null_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", style="b", data=null_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", size="a", data=long_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", size="s", data=long_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", size="a", data=null_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="a", size="s", data=null_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="f", data=object_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="c", size="f", data=object_df)
        ax.clear()

        scatterplot(x="x", y="y", hue="f", size="s", data=object_df)
        ax.clear()


================================================
FILE: tests/test_statistics.py
================================================
import numpy as np
import pandas as pd

try:
    import statsmodels.distributions as smdist
except ImportError:
    smdist = None

import pytest
from numpy.testing import assert_array_equal, assert_array_almost_equal

from seaborn._statistics import (
    KDE,
    Histogram,
    ECDF,
    EstimateAggregator,
    LetterValues,
    WeightedAggregator,
    _validate_errorbar_arg,
    _no_scipy,
)


class DistributionFixtures:
    """Mixin providing random sample fixtures drawn from the `rng` fixture."""

    @pytest.fixture
    def x(self, rng):
        """100 normal draws with location 0 and scale 1."""
        return rng.normal(loc=0, scale=1, size=100)

    @pytest.fixture
    def x2(self, rng):
        """742 normal draws with location 0 and scale 1."""
        # random value to avoid edge cases
        return rng.normal(loc=0, scale=1, size=742)

    @pytest.fixture
    def y(self, rng):
        """100 normal draws with location 0 and scale 5."""
        return rng.normal(loc=0, scale=5, size=100)

    @pytest.fixture
    def weights(self, rng):
        """100 uniform draws between 0 and 5."""
        return rng.uniform(low=0, high=5, size=100)


class TestKDE:
    """Tests for the `KDE` estimator on univariate and bivariate samples."""

    def integrate(self, y, x):
        """Integrate `y` over the grid `x` with the trapezoid rule."""
        heights = np.asarray(y)
        grid = np.asarray(x)
        widths = np.diff(grid)
        left_part = widths * heights[:-1]
        right_part = widths * heights[1:]
        return (left_part + right_part).sum() / 2

    def test_gridsize(self, rng):
        """Density and support each contain `gridsize` points."""
        sample = rng.normal(0, 3, 1000)

        gridsize = 200
        estimator = KDE(gridsize=gridsize)
        density, support = estimator(sample)
        assert density.size == gridsize
        assert support.size == gridsize

    def test_cut(self, rng):
        """`cut` extends the support past the data by multiples of the bandwidth."""
        sample = rng.normal(0, 3, 1000)
        data_min, data_max = sample.min(), sample.max()

        _, support = KDE(cut=0)(sample)
        assert support.min() == data_min
        assert support.max() == data_max

        cut = 2
        bw_scale = .5
        padding = sample.std() * bw_scale * cut
        _, support = KDE(cut=cut, bw_method=bw_scale, gridsize=1000)(sample)
        assert support.min() == pytest.approx(data_min - padding, abs=1e-2)
        assert support.max() == pytest.approx(data_max + padding, abs=1e-2)

    def test_clip(self, rng):
        """The support stays inside the `clip` interval."""
        sample = rng.normal(0, 3, 100)
        lower, upper = -1, 1
        _, support = KDE(clip=(lower, upper))(sample)

        assert support.min() >= lower
        assert support.max() <= upper

    def test_density_normalization(self, rng):
        """The estimated density integrates to one."""
        sample = rng.normal(0, 3, 1000)
        density, support = KDE()(sample)
        area = self.integrate(density, support)
        assert area == pytest.approx(1, abs=1e-5)

    @pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
    def test_cumulative(self, rng):
        """The cumulative estimate runs from about zero to about one."""
        sample = rng.normal(0, 3, 1000)
        cdf, _ = KDE(cumulative=True)(sample)
        first, last = cdf[0], cdf[-1]
        assert first == pytest.approx(0, abs=1e-5)
        assert last == pytest.approx(1, abs=1e-5)

    def test_cached_support(self, rng):
        """A support defined up front is reused when evaluating a subset."""
        sample = rng.normal(0, 3, 100)
        estimator = KDE()
        estimator.define_support(sample)
        inside = (sample > -1) & (sample < 1)
        _, support = estimator(sample[inside])
        assert_array_equal(support, estimator.support)

    def test_bw_method(self, rng):
        """A smaller `bw_method` gives a rougher density curve."""
        sample = rng.normal(0, 3, 100)
        narrow = KDE(bw_method=.2)
        wide = KDE(bw_method=2)

        rough, _ = narrow(sample)
        smooth, _ = wide(sample)

        def mean_abs_step(curve):
            return np.abs(np.diff(curve)).mean()

        assert mean_abs_step(rough) > mean_abs_step(smooth)

    def test_bw_adjust(self, rng):
        """A smaller `bw_adjust` gives a rougher density curve."""
        sample = rng.normal(0, 3, 100)
        narrow = KDE(bw_adjust=.2)
        wide = KDE(bw_adjust=2)

        rough, _ = narrow(sample)
        smooth, _ = wide(sample)

        def mean_abs_step(curve):
            return np.abs(np.diff(curve)).mean()

        assert mean_abs_step(rough) > mean_abs_step(smooth)

    def test_bivariate_grid(self, rng):
        """The bivariate density is a square grid of side `gridsize`."""
        gridsize = 100
        xs, ys = rng.normal(0, 3, (2, 50))
        density, (grid_x, grid_y) = KDE(gridsize=gridsize)(xs, ys)

        assert density.shape == (gridsize, gridsize)
        assert grid_x.size == gridsize
        assert grid_y.size == gridsize

    def test_bivariate_normalization(self, rng):
        """The bivariate density sums to roughly one over the grid."""
        xs, ys = rng.normal(0, 3, (2, 50))
        density, (grid_x, grid_y) = KDE(gridsize=100)(xs, ys)

        step_x = grid_x[1] - grid_x[0]
        step_y = grid_y[1] - grid_y[0]
        cell_area = step_x * step_y

        total = density.sum() * cell_area
        assert total == pytest.approx(1, abs=1e-2)

    @pytest.mark.skipif(_no_scipy, reason="Test requires scipy")
    def test_bivariate_cumulative(self, rng):
        """The bivariate cumulative estimate runs from about zero to about one."""
        xs, ys = rng.normal(0, 3, (2, 50))
        cdf, _ = KDE(gridsize=100, cumulative=True)(xs, ys)

        assert cdf[0, 0] == pytest.approx(0, abs=1e-2)
        assert cdf[-1, -1] == pytest.approx(1, abs=1e-2)


class TestHistogram(DistributionFixtures):
    """Tests for the `Histogram` estimator: bin definition and normalization."""

    def test_string_bins(self, x):
        """A named bin rule yields the data range and the rule's bin count."""

        h = Histogram(bins="sqrt")
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == (x.min(), x.max())
        assert bin_kws["bins"] == int(np.sqrt(len(x)))

    def test_int_bins(self, x):
        """An integer `bins` is passed through with the data range."""

        n = 24
        h = Histogram(bins=n)
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == (x.min(), x.max())
        assert bin_kws["bins"] == n

    def test_array_bins(self, x):
        """Explicit bin edges are used unchanged."""

        bins = [-3, -2, 1, 2, 3]
        h = Histogram(bins=bins)
        bin_kws = h.define_bin_params(x)
        assert_array_equal(bin_kws["bins"], bins)

    def test_bivariate_string_bins(self, x, y):
        """Named bin rules apply to both variables, jointly or per variable."""

        s1, s2 = "sqrt", "fd"

        h = Histogram(bins=s1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, np.histogram_bin_edges(x, s1))
        assert_array_equal(e2, np.histogram_bin_edges(y, s1))

        h = Histogram(bins=(s1, s2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, np.histogram_bin_edges(x, s1))
        assert_array_equal(e2, np.histogram_bin_edges(y, s2))

    def test_bivariate_int_bins(self, x, y):
        """Integer bin counts apply to both variables, jointly or per variable."""

        b1, b2 = 5, 10

        h = Histogram(bins=b1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert len(e1) == b1 + 1
        assert len(e2) == b1 + 1

        h = Histogram(bins=(b1, b2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert len(e1) == b1 + 1
        assert len(e2) == b2 + 1

    def test_bivariate_array_bins(self, x, y):
        """Explicit edges apply to both variables, jointly or per variable."""

        b1 = [-3, -2, 1, 2, 3]
        b2 = [-5, -2, 3, 6]

        h = Histogram(bins=b1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, b1)
        assert_array_equal(e2, b1)

        h = Histogram(bins=(b1, b2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert_array_equal(e1, b1)
        assert_array_equal(e2, b2)

    def test_binwidth(self, x):
        """`binwidth` determines the width implied by the range and bin count."""

        binwidth = .5
        h = Histogram(binwidth=binwidth)
        bin_kws = h.define_bin_params(x)
        n_bins = bin_kws["bins"]
        left, right = bin_kws["range"]
        assert (right - left) / n_bins == pytest.approx(binwidth)

    def test_bivariate_binwidth(self, x, y):
        """`binwidth` sets the edge spacing for both variables."""

        w1, w2 = .5, 1

        # Edge spacing is a floating-point difference; compare with tolerance
        h = Histogram(binwidth=w1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert np.allclose(np.diff(e1), w1)
        assert np.allclose(np.diff(e2), w1)

        h = Histogram(binwidth=(w1, w2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert np.allclose(np.diff(e1), w1)
        assert np.allclose(np.diff(e2), w2)

    def test_binrange(self, x):
        """`binrange` is passed through as the histogram range."""

        binrange = (-4, 4)
        h = Histogram(binrange=binrange)
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == binrange

    def test_bivariate_binrange(self, x, y):
        """`binrange` bounds the edges of both variables, jointly or per variable."""

        r1, r2 = (-4, 4), (-10, 10)

        h = Histogram(binrange=r1)
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert e1.min() == r1[0]
        assert e1.max() == r1[1]
        assert e2.min() == r1[0]
        assert e2.max() == r1[1]

        h = Histogram(binrange=(r1, r2))
        e1, e2 = h.define_bin_params(x, y)["bins"]
        assert e1.min() == r1[0]
        assert e1.max() == r1[1]
        assert e2.min() == r2[0]
        assert e2.max() == r2[1]

    def test_discrete_bins(self, rng):
        """`discrete=True` gives one unit-width bin centered on each integer."""

        x = rng.binomial(20, .5, 100)
        h = Histogram(discrete=True)
        bin_kws = h.define_bin_params(x)
        assert bin_kws["range"] == (x.min() - .5, x.max() + .5)
        assert bin_kws["bins"] == (x.max() - x.min() + 1)

    def test_odd_single_observation(self):
        """A single observation with `binwidth` yields one bin of that width."""
        # GH2721
        x = np.array([0.49928])
        h, e = Histogram(binwidth=0.03)(x)
        assert len(h) == 1
        assert (e[1] - e[0]) == pytest.approx(.03)

    def test_binwidth_roundoff(self):
        """Round-off in the bin edges does not drop any observation."""
        # GH2785
        x = np.array([2.4, 2.5, 2.6])
        h, _ = Histogram(binwidth=0.01)(x)
        assert h.sum() == 3

    def test_histogram(self, x):
        """The default histogram matches `np.histogram` with `bins="auto"`."""

        h = Histogram()
        heights, edges = h(x)
        heights_mpl, edges_mpl = np.histogram(x, bins="auto")

        assert_array_equal(heights, heights_mpl)
        assert_array_equal(edges, edges_mpl)

    def test_count_stat(self, x):
        """Counts sum to the number of observations."""

        h = Histogram(stat="count")
        heights, _ = h(x)
        assert heights.sum() == len(x)

    def test_density_stat(self, x):
        """Density heights integrate to one over the bins."""

        h = Histogram(stat="density")
        heights, edges = h(x)
        # Floating-point sum: compare with tolerance, not exact equality
        assert (heights * np.diff(edges)).sum() == pytest.approx(1)

    def test_probability_stat(self, x):
        """Probability heights sum to one."""

        h = Histogram(stat="probability")
        heights, _ = h(x)
        assert heights.sum() == pytest.approx(1)

    def test_frequency_stat(self, x):
        """Frequency heights times bin widths sum to the number of observations."""

        h = Histogram(stat="frequency")
        heights, edges = h(x)
        assert (heights * np.diff(edges)).sum() == pytest.approx(len(x))

    def test_cumulative_count(self, x):
        """The last cumulative count equals the number of observations."""

        h = Histogram(stat="count", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == len(x)

    def test_cumulative_density(self, x):
        """The last cumulative density value is one."""

        h = Histogram(stat="density", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == pytest.approx(1)

    def test_cumulative_probability(self, x):
        """The last cumulative probability value is one."""

        h = Histogram(stat="probability", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == pytest.approx(1)

    def test_cumulative_frequency(self, x):
        """The last cumulative frequency equals the number of observations."""

        h = Histogram(stat="frequency", cumulative=True)
        heights, _ = h(x)
        assert heights[-1] == pytest.approx(len(x))

    def test_bivariate_histogram(self, x, y):
        """The default bivariate histogram matches `np.histogram2d` on "auto" edges."""

        h = Histogram()
        heights, edges = h(x, y)
        bins_mpl = (
            np.histogram_bin_edges(x, "auto"),
            np.histogram_bin_edges(y, "auto"),
        )
        heights_mpl, *edges_mpl = np.histogram2d(x, y, bins_mpl)
        assert_array_equal(heights, heights_mpl)
        assert_array_equal(edges[0], edges_mpl[0])
        assert_array_equal(edges[1], edges_mpl[1])

    def test_bivariate_count_stat(self, x, y):
        """Bivariate counts sum to the number of observations."""

        h = Histogram(stat="count")
        heights, _ = h(x, y)
        assert heights.sum() == len(x)

    def test_bivariate_density_stat(self, x, y):
        """Bivariate density heights integrate to one over the bin areas."""

        h = Histogram(stat="density")
        heights, (edges_x, edges_y) = h(x, y)
        areas = np.outer(np.diff(edges_x), np.diff(edges_y))
        assert (heights * areas).sum() == pytest.approx(1)

    def test_bivariate_probability_stat(self, x, y):
        """Bivariate probability heights sum to one."""

        h = Histogram(stat="probability")
        heights, _ = h(x, y)
        assert heights.sum() == pytest.approx(1)

    def test_bivariate_frequency_stat(self, x, y):
        """Bivariate frequency times bin areas sums to the number of observations."""

        h = Histogram(stat="frequency")
        heights, (x_edges, y_edges) = h(x, y)
        area = np.outer(np.diff(x_edges), np.diff(y_edges))
        assert (heights * area).sum() == pytest.approx(len(x))

    def test_bivariate_cumulative_count(self, x, y):
        """The final bivariate cumulative count equals the number of observations."""

        h = Histogram(stat="count", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == len(x)

    def test_bivariate_cumulative_density(self, x, y):
        """The final bivariate cumulative density value is one."""

        h = Histogram(stat="density", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == pytest.approx(1)

    def test_bivariate_cumulative_frequency(self, x, y):
        """The final bivariate cumulative frequency equals the number of observations."""

        h = Histogram(stat="frequency", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == pytest.approx(len(x))

    def test_bivariate_cumulative_probability(self, x, y):
        """The final bivariate cumulative probability value is one."""

        h = Histogram(stat="probability", cumulative=True)
        heights, _ = h(x, y)
        assert heights[-1, -1] == pytest.approx(1)

    def test_bad_stat(self):
        """An unknown `stat` raises a `ValueError` at construction."""

        with pytest.raises(ValueError):
            Histogram(stat="invalid")


class TestECDF(DistributionFixtures):
    """Tests for the `ECDF` estimator."""

    def test_univariate_proportion(self, x):
        """Proportions rise evenly from 0 to 1 over the sorted values."""

        ecdf = ECDF()
        stat, vals = ecdf(x)
        assert_array_equal(vals[1:], np.sort(x))
        assert_array_almost_equal(stat[1:], np.linspace(0, 1, len(x) + 1)[1:])
        assert stat[0] == 0

    def test_univariate_count(self, x):
        """Counts rise from 0 to the number of observations."""

        ecdf = ECDF(stat="count")
        stat, vals = ecdf(x)

        assert_array_equal(vals[1:], np.sort(x))
        assert_array_almost_equal(stat[1:], np.arange(len(x)) + 1)
        assert stat[0] == 0

    def test_univariate_percent(self, x2):
        """Percentages rise from 0 to 100 over the sorted values."""

        ecdf = ECDF(stat="percent")
        stat, vals = ecdf(x2)

        assert_array_equal(vals[1:], np.sort(x2))
        assert_array_almost_equal(stat[1:], (np.arange(len(x2)) + 1) / len(x2) * 100)
        assert stat[0] == 0

    def test_univariate_proportion_weights(self, x, weights):
        """Weighted proportions follow the normalized cumulative weights."""

        ecdf = ECDF()
        stat, vals = ecdf(x, weights=weights)
        assert_array_equal(vals[1:], np.sort(x))
        expected_stats = weights[x.argsort()].cumsum() / weights.sum()
        assert_array_almost_equal(stat[1:], expected_stats)
        assert stat[0] == 0

    def test_univariate_count_weights(self, x, weights):
        """Weighted counts follow the cumulative weights."""

        ecdf = ECDF(stat="count")
        stat, vals = ecdf(x, weights=weights)
        assert_array_equal(vals[1:], np.sort(x))
        assert_array_almost_equal(stat[1:], weights[x.argsort()].cumsum())
        assert stat[0] == 0

    @pytest.mark.skipif(smdist is None, reason="Requires statsmodels")
    def test_against_statsmodels(self, x):
        """Results agree with the statsmodels ECDF, including the complement."""

        sm_ecdf = smdist.empirical_distribution.ECDF(x)

        ecdf = ECDF()
        stat, vals = ecdf(x)
        assert_array_equal(vals, sm_ecdf.x)
        assert_array_almost_equal(stat, sm_ecdf.y)

        ecdf = ECDF(complementary=True)
        stat, vals = ecdf(x)
        assert_array_equal(vals, sm_ecdf.x)
        assert_array_almost_equal(stat, sm_ecdf.y[::-1])

    def test_invalid_stat(self, x):
        """An unsupported `stat` raises a `ValueError` at construction."""

        with pytest.raises(ValueError, match="`stat` must be one of"):
            ECDF(stat="density")

    def test_bivariate_error(self, x, y):
        """Calling the estimator with two variables raises `NotImplementedError`."""

        # Construct outside the context so only the bivariate call is guarded
        ecdf = ECDF()
        with pytest.raises(NotImplementedError, match="Bivariate ECDF"):
            ecdf(x, y)


class TestEstimateAggregator:
    """Tests for EstimateAggregator estimates and error-bar intervals."""

    def test_func_estimator(self, long_df):
        """A callable estimator is applied to the requested column."""
        func = np.mean
        agg = EstimateAggregator(func)
        out = agg(long_df, "x")
        assert out["x"] == func(long_df["x"])

    def test_name_estimator(self, long_df):
        """A string estimator resolves to the matching pandas reduction."""
        agg = EstimateAggregator("mean")
        out = agg(long_df, "x")
        assert out["x"] == long_df["x"].mean()

    def test_custom_func_estimator(self, long_df):
        """A user-defined function can serve as the estimator."""

        def func(x):
            return np.asarray(x).min()

        agg = EstimateAggregator(func)
        out = agg(long_df, "x")
        assert out["x"] == func(long_df["x"])

    def test_se_errorbars(self, long_df):
        """"se" intervals are the mean +/- (scale * standard error)."""
        agg = EstimateAggregator("mean", "se")
        out = agg(long_df, "x")
        assert out["x"] == long_df["x"].mean()
        assert out["xmin"] == (long_df["x"].mean() - long_df["x"].sem())
        assert out["xmax"] == (long_df["x"].mean() + long_df["x"].sem())

        agg = EstimateAggregator("mean", ("se", 2))
        out = agg(long_df, "x")
        assert out["x"] == long_df["x"].mean()
        assert out["xmin"] == (long_df["x"].mean() - 2 * long_df["x"].sem())
        assert out["xmax"] == (long_df["x"].mean() + 2 * long_df["x"].sem())

    def test_sd_errorbars(self, long_df):
        """"sd" intervals are the mean +/- (scale * standard deviation)."""
        agg = EstimateAggregator("mean", "sd")
        out = agg(long_df, "x")
        assert out["x"] == long_df["x"].mean()
        assert out["xmin"] == (long_df["x"].mean() - long_df["x"].std())
        assert out["xmax"] == (long_df["x"].mean() + long_df["x"].std())

        agg = EstimateAggregator("mean", ("sd", 2))
        out = agg(long_df, "x")
        assert out["x"] == long_df["x"].mean()
        assert out["xmin"] == (long_df["x"].mean() - 2 * long_df["x"].std())
        assert out["xmax"] == (long_df["x"].mean() + 2 * long_df["x"].std())

    def test_pi_errorbars(self, long_df):
        """"pi" intervals are data percentiles (2.5-97.5 by default)."""
        agg = EstimateAggregator("mean", "pi")
        out = agg(long_df, "y")
        assert out["ymin"] == np.percentile(long_df["y"], 2.5)
        assert out["ymax"] == np.percentile(long_df["y"], 97.5)

        # A level of 50 should give the interquartile bounds
        agg = EstimateAggregator("mean", ("pi", 50))
        out = agg(long_df, "y")
        assert out["ymin"] == np.percentile(long_df["y"], 25)
        assert out["ymax"] == np.percentile(long_df["y"], 75)

    def test_ci_errorbars(self, long_df):
        """Bootstrap "ci" intervals approximate se-based ones and are seedable."""
        # Default level: compare against mean +/- 1.96 standard errors
        agg = EstimateAggregator("mean", "ci", n_boot=100000, seed=0)
        out = agg(long_df, "y")

        agg_ref = EstimateAggregator("mean", ("se", 1.96))
        out_ref = agg_ref(long_df, "y")

        assert out["ymin"] == pytest.approx(out_ref["ymin"], abs=1e-2)
        assert out["ymax"] == pytest.approx(out_ref["ymax"], abs=1e-2)

        # Level 68: compare against mean +/- 1 standard error
        agg = EstimateAggregator("mean", ("ci", 68), n_boot=100000, seed=0)
        out = agg(long_df, "y")

        agg_ref = EstimateAggregator("mean", ("se", 1))
        out_ref = agg_ref(long_df, "y")

        assert out["ymin"] == pytest.approx(out_ref["ymin"], abs=1e-2)
        assert out["ymax"] == pytest.approx(out_ref["ymax"], abs=1e-2)

        # Repeated calls on the seeded bootstrap aggregator must agree. This
        # has to exercise `agg`; the "se" reference involves no resampling,
        # so comparing it with itself would pass regardless of the seed.
        agg = EstimateAggregator("mean", "ci", seed=0)
        out_orig = agg(long_df, "y")
        out_test = agg(long_df, "y")
        assert_array_equal(out_orig, out_test)

    def test_custom_errorbars(self, long_df):
        """A callable error-bar method supplies the (min, max) bounds directly."""
        f = lambda x: (x.min(), x.max())  # noqa: E731
        agg = EstimateAggregator("mean", f)
        out = agg(long_df, "y")
        assert out["ymin"] == long_df["y"].min()
        assert out["ymax"] == long_df["y"].max()

    def test_singleton_errorbars(self):
        """A single observation yields its value with NaN interval bounds."""
        agg = EstimateAggregator("mean", "ci")
        val = 7
        out = agg(pd.DataFrame(dict(y=[val])), "y")
        assert out["y"] == val
        assert pd.isna(out["ymin"])
        assert pd.isna(out["ymax"])

    def test_errorbar_validation(self):
        """Valid `errorbar` specs are normalized; malformed ones raise."""
        method, level = _validate_errorbar_arg(("ci", 99))
        assert method == "ci"
        assert level == 99

        # A bare method name gets a default level
        method, level = _validate_errorbar_arg("sd")
        assert method == "sd"
        assert level == 1

        # A callable passes through unchanged and has no level
        f = lambda x: (x.min(), x.max())  # noqa: E731
        method, level = _validate_errorbar_arg(f)
        assert method is f
        assert level is None

        bad_args = [
            ("sem", ValueError),
            (("std", 2), ValueError),
            (("pi", 5, 95), ValueError),
            (95, TypeError),
            (("ci", "large"), TypeError),
        ]

        for arg, exception in bad_args:
            with pytest.raises(exception, match="`errorbar` must be"):
                _validate_errorbar_arg(arg)


class TestWeightedAggregator:
    """Tests for the weighted estimate aggregator."""

    def test_weighted_mean(self, long_df):
        """Without an error-bar method the bounds are NaN around the average."""
        long_df["weight"] = long_df["x"]
        result = WeightedAggregator("mean")(long_df, "y")

        target = np.average(long_df["y"], weights=long_df["weight"])
        assert_array_equal(result["y"], target)
        for bound in ("ymin", "ymax"):
            assert_array_equal(result[bound], np.nan)

    def test_weighted_ci(self, long_df):
        """With "ci" the interval brackets the weighted average."""
        long_df["weight"] = long_df["x"]
        result = WeightedAggregator("mean", "ci")(long_df, "y")

        target = np.average(long_df["y"], weights=long_df["weight"])
        assert_array_equal(result["y"], target)

        center = result["y"]
        assert (result["ymin"] <= center).all()
        assert (result["ymax"] >= center).all()

    def test_limited_estimator(self):
        """Only the "mean" estimator is accepted."""
        expected_message = "Weighted estimator must be 'mean'"
        with pytest.raises(ValueError, match=expected_message):
            WeightedAggregator("median")

    def test_limited_ci(self):
        """Only the "ci" error-bar method is accepted."""
        expected_message = "Error bar method must be 'ci'"
        with pytest.raises(ValueError, match=expected_message):
            WeightedAggregator("mean", "sd")


class TestLetterValues:
    """Tests for the letter-value summary statistics."""

    @pytest.fixture
    def x(self, rng):
        """Series of 10,000 draws from `rng.standard_t(10, ...)`."""
        draws = rng.standard_t(10, 10_000)
        return pd.Series(draws)

    def test_levels(self, x):
        """Levels ascend from 0 to k - 1 and then descend back to 0."""
        estimator = LetterValues(k_depth="tukey", outlier_prop=0, trust_alpha=0)
        summary = estimator(x)

        depth = summary["k"]
        ascending = np.arange(depth)
        descending = np.arange(depth - 1)[::-1]
        assert_array_equal(
            summary["levels"], np.concatenate([ascending, descending])
        )

    def test_values(self, x):
        """Reported values are the data percentiles at the reported `percs`."""
        estimator = LetterValues(k_depth="tukey", outlier_prop=0, trust_alpha=0)
        summary = estimator(x)

        from_percentiles = np.percentile(x, summary["percs"])
        assert_array_equal(from_percentiles, summary["values"])

    def test_fliers(self, x):
        """Every flier lies outside the range spanned by the letter values."""
        estimator = LetterValues(k_depth="tukey", outlier_prop=0, trust_alpha=0)
        summary = estimator(x)

        points = summary["fliers"]
        letter_values = summary["values"]
        below = points < letter_values.min()
        above = points > letter_values.max()
        assert (below | above).all()

    def test_median(self, x):
        """The reported median equals the sample median."""
        estimator = LetterValues(k_depth="tukey", outlier_prop=0, trust_alpha=0)
        summary = estimator(x)
        assert summary["median"] == np.median(x)

    def test_k_depth_int(self, x):
        """An integer `k_depth` is used as given and yields 2k - 1 levels."""
        k = 12
        estimator = LetterValues(k_depth=k, outlier_prop=0, trust_alpha=0)
        summary = estimator(x)

        assert summary["k"] == k
        assert len(summary["levels"]) == (2 * k - 1)

    def test_trust_alpha(self, x):
        """A larger `trust_alpha` gives a greater depth with "trustworthy"."""

        def depth_for(alpha):
            estimator = LetterValues(
                k_depth="trustworthy", outlier_prop=0, trust_alpha=alpha
            )
            return estimator(x)["k"]

        loose_depth = depth_for(.1)
        strict_depth = depth_for(.001)
        assert loose_depth > strict_depth

    def test_outlier_prop(self, x):
        """A smaller `outlier_prop` gives a greater depth with "proportion"."""

        def depth_for(prop):
            estimator = LetterValues(
                k_depth="proportion", outlier_prop=prop, trust_alpha=0
            )
            return estimator(x)["k"]

        few_outliers_depth = depth_for(.001)
        many_outliers_depth = depth_for(.1)
        assert few_outliers_depth > many_outliers_depth


================================================
FILE: tests/test_utils.py
================================================
"""Tests for seaborn utility functions."""
import re
import tempfile
from types import ModuleType
from urllib.request import urlopen
from http.client import HTTPException

import numpy as np
import pandas as pd
import matplotlib as mpl
import matplotlib.pyplot as plt
from cycler import cycler

import pytest
from numpy.testing import (
    assert_array_equal,
)
from pandas.testing import (
    assert_series_equal,
    assert_frame_equal,
)

from seaborn import utils, rcmod, scatterplot
from seaborn.utils import (
    get_dataset_names,
    get_color_cycle,
    remove_na,
    load_dataset,
    _assign_default_kwargs,
    _check_argument,
    _draw_figure,
    _deprecate_ci,
    _version_predates, DATASET_NAMES_URL,
)
from seaborn._compat import get_legend_handles


# Module-level sample of 100 standard-normal draws from NumPy's global RNG
# (no seed is set in this module before this line).
a_norm = np.random.randn(100)


def _network(t=None, url="https://github.com"):
    """
    Decorator that will skip a test if `url` is unreachable.

    Can be applied bare (``@_network``) or with a keyword argument
    (``@_network(url=...)``).

    Parameters
    ----------
    t : function, optional
        Test function to wrap. When omitted, a decorator bound to `url`
        is returned instead.
    url : str, optional
        Address that is opened before each run of the wrapped test.

    Returns
    -------
    function
        The wrapped test, or a decorator when `t` is None.

    """
    # Imported locally so this helper does not rely on module-level imports
    import functools

    if t is None:
        return lambda x: _network(x, url=url)

    # Preserve the test's name, docstring, and signature on the wrapper
    @functools.wraps(t)
    def wrapper(*args, **kwargs):
        # attempt to connect
        try:
            f = urlopen(url)
        except (OSError, HTTPException):
            pytest.skip("No internet connection")
        else:
            # Only the reachability matters; release the connection first
            f.close()
            return t(*args, **kwargs)
    return wrapper


def test_ci_to_errsize():
    """Check that interval bounds become distances from the given heights."""
    heights = [1, 1.5]
    lower_bounds = [.5, .5]
    upper_bounds = [1.25, 1.5]

    # Row 0: height minus lower bound; row 1: upper bound minus height
    expected = np.array([[.5, 1], [.25, 0]])

    computed = utils.ci_to_errsize([lower_bounds, upper_bounds], heights)
    assert_array_equal(expected, computed)


def test_desaturate():
    """Check desaturation for named, hex, and RGB-tuple color inputs."""
    cases = [
        ("red", .5, (.75, .25, .25)),
        ("#00FF00", .5, (.25, .75, .25)),
        ((0, 0, 1), .5, (.25, .25, .75)),
        ("red", .5, (.75, .25, .25)),
        # A proportion of 1 should leave the color unchanged
        ("lightblue", 1, mpl.colors.to_rgb("lightblue")),
    ]
    for color, prop, expected in cases:
        assert utils.desaturate(color, prop) == expected


def test_desaturation_prop():
    """Check that a proportion outside [0, 1] raises ValueError."""
    out_of_range = 50
    with pytest.raises(ValueError):
        utils.desaturate("blue", out_of_range)


def test_saturate():
    """Check that a muted red saturates to pure red."""
    muted_red = (.75, .25, .25)
    assert utils.saturate(muted_red) == (1, 0, 0)


@pytest.mark.parametrize(
    ("s", "exp"),
    [
        ("a", "a"),
        ("abc", "abc"),
        (b"a", "a"),
        (b"abc", "abc"),
        (bytearray("abc", "utf-8"), "abc"),
        (bytearray(), ""),
        (1, "1"),
        (0, "0"),
        ([], str([])),
    ],
)
def test_to_utf8(s, exp):
    """Check that to_utf8 turns str, bytes-like, and other objects into str."""
    converted = utils.to_utf8(s)
    assert isinstance(converted, str)
    assert converted == exp


class TestSpineUtils:
    """Tests for `utils.despine`: visibility, offsets, trimming, and ticks."""

    sides = ["left", "right", "bottom", "top"]
    outer_sides = ["top", "right"]
    inner_sides = ["left", "bottom"]

    offset = 10
    original_position = ("outward", 0)
    offset_position = ("outward", offset)

    def test_despine(self):
        """Default call hides top/right; explicit flags hide every spine."""
        f, ax = plt.subplots()

        def shown(side):
            return ax.spines[side].get_visible()

        assert all(shown(side) for side in self.sides)

        # No arguments: acts on the current axes with the default sides
        utils.despine()
        assert not any(shown(side) for side in self.outer_sides)
        assert all(shown(side) for side in self.inner_sides)

        utils.despine(**{side: True for side in self.sides})
        assert not any(shown(side) for side in self.sides)

    def test_despine_specific_axes(self):
        """Passing `ax` restricts the change to that axes only."""
        f, (untouched, target) = plt.subplots(2, 1)

        utils.despine(ax=target)

        for side in self.sides:
            assert untouched.spines[side].get_visible()

        for side in self.outer_sides:
            assert not target.spines[side].get_visible()
        for side in self.inner_sides:
            assert target.spines[side].get_visible()

    def test_despine_with_offset(self):
        """A scalar offset moves only the spines that remain visible."""
        f, ax = plt.subplots()

        for side in self.sides:
            assert ax.spines[side].get_position() == self.original_position

        utils.despine(ax=ax, offset=self.offset)

        for side in self.sides:
            spine = ax.spines[side]
            if spine.get_visible():
                expected = self.offset_position
            else:
                expected = self.original_position
            assert spine.get_position() == expected

    def test_despine_side_specific_offset(self):
        """A dict offset moves only the sides it names."""
        f, ax = plt.subplots()
        utils.despine(ax=ax, offset={"left": self.offset})

        for side in self.sides:
            spine = ax.spines[side]
            moved = spine.get_visible() and side == "left"
            expected = self.offset_position if moved else self.original_position
            assert spine.get_position() == expected

    def test_despine_with_offset_specific_axes(self):
        """Offsets apply to the given axes and leave its sibling alone."""
        f, (untouched, target) = plt.subplots(2, 1)

        utils.despine(offset=self.offset, ax=target)

        for side in self.sides:
            assert untouched.spines[side].get_position() == self.original_position

            spine = target.spines[side]
            if spine.get_visible():
                expected = self.offset_position
            else:
                expected = self.original_position
            assert spine.get_position() == expected

    def test_despine_trim_spines(self):
        """Trimming gives the remaining spines bounds of (1, 3)."""
        f, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3])
        ax.set_xlim(.75, 3.25)

        utils.despine(trim=True)
        for side in self.inner_sides:
            assert ax.spines[side].get_bounds() == (1, 3)

    def test_despine_trim_inverted(self):
        """Trimmed bounds are (1, 3) even when the y axis is inverted."""
        f, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3])
        ax.set_ylim(.85, 3.15)
        ax.invert_yaxis()

        utils.despine(trim=True)
        for side in self.inner_sides:
            assert ax.spines[side].get_bounds() == (1, 3)

    def test_despine_trim_noticks(self):
        """Trimming an axis with no ticks leaves it without ticks."""
        f, ax = plt.subplots()
        ax.plot([1, 2, 3], [1, 2, 3])
        ax.set_yticks([])

        utils.despine(trim=True)
        remaining_ticks = ax.get_yticks()
        assert remaining_ticks.size == 0

    def test_despine_trim_categorical(self):
        """Trimming works when the x axis holds categorical positions."""
        f, ax = plt.subplots()
        ax.plot(["a", "b", "c"], [1, 2, 3])

        utils.despine(trim=True)

        assert ax.spines["left"].get_bounds() == (1, 3)
        assert ax.spines["bottom"].get_bounds() == (0, 2)

    def test_despine_moved_ticks(self):
        """Tick visibility on the removed side carries over to the kept side."""
        scenarios = [
            ("yaxis", True, dict(left=True, right=False)),
            ("yaxis", False, dict(left=True, right=False)),
            ("xaxis", True, dict(bottom=True, top=False)),
            ("xaxis", False, dict(bottom=True, top=False)),
        ]

        for axis_name, ticks_on, despine_kws in scenarios:
            f, ax = plt.subplots()
            for tick in getattr(ax, axis_name).majorTicks:
                tick.tick1line.set_visible(ticks_on)

            utils.despine(ax=ax, **despine_kws)

            # Look the ticks up again after despining rather than reusing
            # the list fetched above
            for tick in getattr(ax, axis_name).majorTicks:
                if ticks_on:
                    assert tick.tick2line.get_visible()
                else:
                    assert not tick.tick2line.get_visible()
            plt.close(f)


def test_ticklabels_overlap():
    """Check overlap detection for default and deliberately long x labels."""
    rcmod.set()
    fig, ax = plt.subplots(figsize=(2, 2))
    fig.tight_layout()  # This gets the Agg renderer working

    default_labels = ax.get_xticklabels()
    assert not utils.axis_ticklabels_overlap(default_labels)

    # Two long labels on a 2x2 figure are expected to collide
    long_labels = ("abcdefgh", "ijklmnop")
    ax.set_xlim(-.5, 1.5)
    ax.set_xticks([0, 1])
    ax.set_xticklabels(long_labels)

    assert utils.axis_ticklabels_overlap(ax.get_xticklabels())

    x_overlap, y_overlap = utils.axes_ticklabels_overlap(ax)
    assert x_overlap
    assert not y_overlap


def test_locator_to_legend_entries():
    """Check the legend labels produced for linear and log locators."""
    linear = mpl.ticker.MaxNLocator(nbins=3)
    linear_cases = [
        ((0.09, 0.4), float, ["0.15", "0.30"]),
        ((0.8, 0.9), float, ["0.80", "0.84", "0.88"]),
        ((1, 6), int, ["2", "4", "6"]),
    ]
    for limits, dtype, expected in linear_cases:
        _, labels = utils.locator_to_legend_entries(linear, limits, dtype)
        assert labels == expected

    log = mpl.ticker.LogLocator(numticks=5)
    _, labels = utils.locator_to_legend_entries(log, (5, 1425), int)
    assert labels == ['10', '100', '1000']

    _, labels = utils.locator_to_legend_entries(log, (0.00003, 0.02), float)
    for i, exponent in enumerate([4, 3, 2]):
        # Use regex as mpl switched to minus sign, not hyphen, in 3.6
        assert re.match(f"1e.0{exponent}", labels[i])


def test_move_legend_matplotlib_objects():
    """Check move_legend on plain matplotlib Axes and Figure legends."""

    fig, ax = plt.subplots()

    colors = "C2", "C5"
    labels = "first label", "second label"
    title = "the legend"

    for color, label in zip(colors, labels):
        ax.plot([0, 1], color=color, label=label)
    ax.legend(loc="upper right", title=title)
    # Draw before reading the legend patch extents
    utils._draw_figure(fig)
    # Inverse of the axes transform, used to express extents in axes units
    xfm = ax.transAxes.inverted().transform

    # --- Test axes legend

    old_pos = xfm(ax.legend_.legendPatch.get_extents())

    new_fontsize = 14
    utils.move_legend(ax, "lower left", title_fontsize=new_fontsize)
    utils._draw_figure(fig)
    new_pos = xfm(ax.legend_.legendPatch.get_extents())

    # Every coordinate of the box shrank, i.e. it moved down and to the left;
    # the title text is kept and the new title font size is applied
    assert (new_pos < old_pos).all()
    assert ax.legend_.get_title().get_text() == title
    assert ax.legend_.get_title().get_size() == new_fontsize

    # --- Test title replacement

    new_title = "new title"
    utils.move_legend(ax, "lower left", title=new_title)
    utils._draw_figure(fig)
    assert ax.legend_.get_title().get_text() == new_title

    # --- Test figure legend

    fig.legend(loc="upper right", title=title)
    _draw_figure(fig)
    # Same idea as above, but relative to the figure transform
    xfm = fig.transFigure.inverted().transform
    old_pos = xfm(fig.legends[0].legendPatch.get_extents())

    utils.move_legend(fig, "lower left", title=new_title)
    _draw_figure(fig)

    new_pos = xfm(fig.legends[0].legendPatch.get_extents())
    assert (new_pos < old_pos).all()
    assert fig.legends[0].get_title().get_text() == new_title


def test_move_legend_grid_object(long_df):
    """Check move_legend on a FacetGrid: position, title, and handles."""

    from seaborn.axisgrid import FacetGrid

    hue_var = "a"
    g = FacetGrid(long_df, hue=hue_var)
    g.map(plt.plot, "x", "y")

    g.add_legend()
    # Draw before reading the legend patch extents
    _draw_figure(g.figure)

    # Inverse of the figure transform, used to express extents in figure units
    xfm = g.figure.transFigure.inverted().transform
    old_pos = xfm(g.legend.legendPatch.get_extents())

    fontsize = 20
    utils.move_legend(g, "lower left", title_fontsize=fontsize)
    _draw_figure(g.figure)

    new_pos = xfm(g.legend.legendPatch.get_extents())
    # Every coordinate shrank; the title is the hue variable name at the
    # requested font size
    assert (new_pos < old_pos).all()
    assert g.legend.get_title().get_text() == hue_var
    assert g.legend.get_title().get_size() == fontsize

    # The moved legend still has handles, colored C0, C1, ... in order
    assert get_legend_handles(g.legend)
    for i, h in enumerate(get_legend_handles(g.legend)):
        assert mpl.colors.to_rgb(h.get_color()) == mpl.colors.to_rgb(f"C{i}")


def test_move_legend_input_checks():
    """Check the errors raised for an Axis, and for a bare Axes and Figure."""
    ax = plt.figure().subplots()

    # The Axes and Figure here never had a legend added to them
    failing_inputs = [
        (ax.xaxis, TypeError),
        (ax, ValueError),
        (ax.figure, ValueError),
    ]
    for obj, expected_error in failing_inputs:
        with pytest.raises(expected_error):
            utils.move_legend(obj, "best")


def test_move_legend_with_labels(long_df):
    """Check that new labels replace the legend text but keep handle colors."""
    order = long_df["a"].unique()
    labels = [level.capitalize() for level in order]
    ax = scatterplot(long_df, x="x", y="y", hue="a", hue_order=order)

    def handle_colors():
        handles = get_legend_handles(ax.get_legend())
        return [handle.get_markerfacecolor() for handle in handles]

    # Record the colors before the legend is recreated
    colors_before = handle_colors()
    utils.move_legend(ax, "best", labels=labels)
    _draw_figure(ax.figure)

    shown_text = [text.get_text() for text in ax.get_legend().get_texts()]
    assert shown_text == labels

    colors_after = handle_colors()
    assert colors_before == colors_after

    # One label too few must be rejected
    too_few = labels[:-1]
    with pytest.raises(ValueError, match="Length of new labels"):
        utils.move_legend(ax, "best", labels=too_few)


def check_load_dataset(name):
    """Load `name` without caching and verify that a DataFrame comes back."""
    loaded = load_dataset(name, cache=False)
    assert isinstance(loaded, pd.DataFrame)


def check_load_cached_dataset(name):
    """Load `name` twice through a temporary cache dir and compare results."""
    with tempfile.TemporaryDirectory() as cache_dir:
        load_kws = dict(cache=True, data_home=cache_dir)

        # First call: download and cache
        downloaded = load_dataset(name, **load_kws)

        # Second call: use cached version
        from_cache = load_dataset(name, **load_kws)
        assert_frame_equal(downloaded, from_cache)


@_network(url=DATASET_NAMES_URL)
def test_get_dataset_names():
    """Check that the dataset listing is non-empty and includes "tips"."""
    available = get_dataset_names()
    assert available
    assert "tips" in available


@_network(url=DATASET_NAMES_URL)
def test_load_datasets():
    """Check that every listed dataset loads as a DataFrame."""
    # Heavy test: loads all available datasets. The checks run in an
    # explicit loop because yielding one check per dataset did not take
    # effect under the network decorator.
    dataset_names = get_dataset_names()
    for dataset_name in dataset_names:
        check_load_dataset(dataset_name)


@_network(url=DATASET_NAMES_URL)
def test_load_dataset_string_error():
    """Check that an unknown dataset name raises a descriptive ValueError."""
    name = "bad_name"
    expected_message = f"'{name}' is not one of the example datasets."
    with pytest.raises(ValueError, match=expected_message):
        load_dataset(name)


def test_load_dataset_passed_data_error():
    """Check that passing a DataFrame instead of a name raises TypeError."""
    not_a_name = pd.DataFrame()
    expected_message = "This function accepts only strings"
    with pytest.raises(TypeError, match=expected_message):
        load_dataset(not_a_name)


@_network(url=DATASET_NAMES_URL)
def test_load_cached_datasets():
    """Check that every listed dataset loads identically from the cache.

    The network guard probes ``DATASET_NAMES_URL``, matching the other
    dataset tests, because the first thing this test does is fetch the
    dataset names.
    """
    # Heavy test: downloads and then re-reads all available datasets. The
    # checks run in an explicit loop because yielding one check per dataset
    # did not take effect under the network decorator.
    for name in get_dataset_names():
        check_load_cached_dataset(name)


def test_relative_luminance():
    """Check luminance for known colors and for an array of colors."""
    assert utils.relative_luminance("white") == 1
    assert utils.relative_luminance("#000000") == 0

    mid_tone = utils.relative_luminance((.25, .5, .75))
    assert mid_tone == pytest.approx(0.201624536)

    # Passing the whole array should match color-by-color evaluation
    palette = mpl.cm.RdBu(np.linspace(0, 1, 10))
    one_by_one = [utils.relative_luminance(color) for color in palette]
    all_at_once = utils.relative_luminance(palette)

    for single, batched in zip(one_by_one, all_at_once):
        assert single == pytest.approx(batched)


@pytest.mark.parametrize(
    ("cycler", "result"),
    [
        (cycler(color=["y"]), ["y"]),
        (cycler(color=["k"]), ["k"]),
        (cycler(color=["k", "y"]), ["k", "y"]),
        (cycler(color=["y", "k"]), ["y", "k"]),
        (cycler(color=["b", "r"]), ["b", "r"]),
        (cycler(color=["r", "b"]), ["r", "b"]),
        (cycler(lw=[1, 2]), [".15"]),  # no color in cycle
    ],
)
def test_get_color_cycle(cycler, result):
    """Check the colors reported for the active `axes.prop_cycle`."""
    rc_overrides = {"axes.prop_cycle": cycler}
    with mpl.rc_context(rc=rc_overrides):
        colors = get_color_cycle()
    assert colors == result


def test_remove_na():
    """Check that NaN is dropped from arrays and Series, keeping index labels."""
    raw_values = [1, 2, np.nan, 3]

    cleaned_array = remove_na(np.array(raw_values))
    assert_array_equal(cleaned_array, np.array([1, 2, 3]))

    cleaned_series = remove_na(pd.Series(raw_values))
    # The surviving rows keep their original index labels
    expected_series = pd.Series([1., 2, 3], [0, 1, 3])
    assert_series_equal(cleaned_series, expected_series)


def test_assign_default_kwargs():
    """Check that missing kwargs are filled in while given ones are kept."""

    def call_func(a, b, c, d):
        pass

    def source_func(c=1, d=2):
        pass

    given = {"c": 3}
    completed = _assign_default_kwargs(given, call_func, source_func)

    # "c" keeps the supplied value; "d" takes the default from source_func
    assert completed == {"c": 3, "d": 2}


def test_check_argument():
    """Check accepted and rejected values, with and without prefix matching."""
    opts = ["a", "b", None]
    no_prefix, with_prefix = {}, {"prefix": True}

    assert _check_argument("arg", opts, "a", **no_prefix) == "a"
    assert _check_argument("arg", opts, None, **no_prefix) is None
    assert _check_argument("arg", opts, "aa", **with_prefix) == "aa"
    assert _check_argument("arg", opts, None, **with_prefix) is None

    # "c" is never an option; None is rejected once removed from the options
    rejected = [
        (opts, "c", no_prefix),
        (opts, "c", with_prefix),
        (opts[:-1], None, no_prefix),
        (opts[:-1], None, with_prefix),
    ]
    for options, value, kwargs in rejected:
        with pytest.raises(ValueError, match="The value for `arg`"):
            _check_argument("arg", options, value, **kwargs)


def test_draw_figure():
    """Check that drawing clears the stale flag and fills in tick labels."""
    fig, ax = plt.subplots()
    categories = ["a", "b", "c"]
    ax.plot(categories, [1, 2, 3])

    _draw_figure(fig)
    assert not fig.stale

    # ticklabels are not populated until a draw, but this may change
    first_label = ax.get_xticklabels()[0]
    assert first_label.get_text() == "a"


def test_deprecate_ci():
    """Check the warning and the translated `errorbar` value for old `ci`."""
    msg = "\n\nThe `ci` parameter is deprecated. Use `errorbar="

    # (legacy `ci` value, regex tail of the warning, expected errorbar spec)
    cases = [
        (None, "None", None),
        ("sd", "'sd'", "sd"),
        (68, r"\('ci', 68\)", ("ci", 68)),
    ]
    for ci, suggestion, expected in cases:
        with pytest.warns(FutureWarning, match=msg + suggestion):
            translated = _deprecate_ci(None, ci)
        if expected is None:
            assert translated is None
        else:
            assert translated == expected


def test_version_predates():
    """Check _version_predates against a stand-in module at version 1.2.3."""

    # A bare module object with only `__version__` set stands in for a library
    mock = ModuleType("mock")
    mock.__version__ = "1.2.3"

    # Later versions: the module predates them
    assert _version_predates(mock, "1.2.4")
    assert _version_predates(mock, "1.3")

    # The same or earlier versions: the module does not predate them
    assert not _version_predates(mock, "1.2.3")
    assert not _version_predates(mock, "0.8")
    assert not _version_predates(mock, "1")