Repository: google/latexify_py
Branch: main
Commit: 54dc86971963
Files: 66
Total size: 275.1 KB
Directory structure:
gitextract__eefs9oe/
├── .github/
│ ├── ISSUE_TEMPLATE/
│ │ ├── bug_report.md
│ │ └── feature_request.md
│ ├── pull_request_template.md
│ └── workflows/
│ ├── ci.yml
│ └── release.yml
├── .gitignore
├── CODEOWNERS
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── checks.sh
├── docs/
│ ├── getting_started.md
│ ├── index.md
│ └── parameters.md
├── examples/
│ └── latexify_examples.ipynb
├── pyproject.toml
└── src/
├── integration_tests/
│ ├── __init__.py
│ ├── algorithmic_style_test.py
│ ├── function_expansion_test.py
│ ├── integration_utils.py
│ └── regression_test.py
└── latexify/
├── __init__.py
├── _version.py
├── analyzers.py
├── analyzers_test.py
├── ast_utils.py
├── ast_utils_test.py
├── codegen/
│ ├── __init__.py
│ ├── algorithmic_codegen.py
│ ├── algorithmic_codegen_test.py
│ ├── codegen_utils.py
│ ├── codegen_utils_test.py
│ ├── expression_codegen.py
│ ├── expression_codegen_test.py
│ ├── expression_rules.py
│ ├── expression_rules_test.py
│ ├── function_codegen.py
│ ├── function_codegen_match_test.py
│ ├── function_codegen_test.py
│ ├── identifier_converter.py
│ ├── identifier_converter_test.py
│ ├── latex.py
│ └── latex_test.py
├── config.py
├── exceptions.py
├── frontend.py
├── frontend_test.py
├── generate_latex.py
├── generate_latex_test.py
├── ipython_wrappers.py
├── parser.py
├── parser_test.py
├── test_utils.py
└── transformers/
├── __init__.py
├── assignment_reducer.py
├── assignment_reducer_test.py
├── aug_assign_replacer.py
├── aug_assign_replacer_test.py
├── docstring_remover.py
├── docstring_remover_test.py
├── function_expander.py
├── function_expander_test.py
├── identifier_replacer.py
├── identifier_replacer_test.py
├── prefix_trimmer.py
└── prefix_trimmer_test.py
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/ISSUE_TEMPLATE/bug_report.md
================================================
---
name: Bug report
about: Create a report to help us improve
title: ''
labels: triage
assignees: odashi
---
## Environment
If you used latexify in the browser, fill in the following items.
* Browser:
* Frontend:
If you used latexify in your own environment, fill in at least the following items.
Feel free to add other items if you think they are useful.
* OS:
* Python:
* Package manager:
* Latexify version:
## Description
Describe the details of the issue. Feel free to insert screenshots if they are useful.
## Reproduction
Describe how other people can reproduce the issue.
## Expected behavior
Describe how latexify should behave in the case above.
================================================
FILE: .github/ISSUE_TEMPLATE/feature_request.md
================================================
---
name: Feature request
about: Suggest an idea for this project
title: ''
labels: feature
assignees: odashi
---
## Description
Is your feature request related to a problem? Please describe it.
A clear and concise description is recommended to keep the discussion efficient.
## Ideas of the solution
If you have an idea for the solution you'd like, describe it in detail.
## Alternative ideas
If you have already considered other ideas, describe them as well.
These ideas may also help us to make reasonable decisions.
## Additional context
Add any other context or screenshots about the feature request here.
================================================
FILE: .github/pull_request_template.md
================================================
# Overview
# Details
# References
# Blocked by
================================================
FILE: .github/workflows/ci.yml
================================================
name: Continuous integration
on:
push:
branches:
- main
pull_request:
branches: ["**"]
jobs:
unit-tests:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"]
steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python-version }}
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install -e ".[dev]"
- name: Test
run: python -m pytest src
black:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install black
- name: Check
run: python -m black -v --check src
flake8:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pyproject-flake8
- name: Check
run: pflake8 -v src
isort:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install isort
- name: Check
run: python -m isort --check src
mypy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- name: Set up Python
uses: actions/setup-python@v4
with:
python-version: "3.11"
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install '.[mypy]'
- name: Check
run: python -m mypy src
================================================
FILE: .github/workflows/release.yml
================================================
name: Release workflow
on:
push:
tags:
- "v[0123456789].*"
jobs:
release:
runs-on: ubuntu-latest
steps:
- name: checkout
uses: actions/checkout@v3
- name: setup python
uses: actions/setup-python@v2
with:
python-version: "3.10"
- name: build
run: |
python -m pip install --upgrade build hatch
python -m hatch version "${GITHUB_REF_NAME}"
python -m build
- name: publish
uses: pypa/gh-action-pypi-publish@release/v1
with:
password: ${{ secrets.PYPI_API_TOKEN }}
================================================
FILE: .gitignore
================================================
# Temporary files
*.swp
temp
tmp
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# PyCharm project settings
.idea/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
================================================
FILE: CODEOWNERS
================================================
* @odashi
================================================
FILE: CONTRIBUTING.md
================================================
# How to Contribute
We'd love to accept your patches and contributions to this project. There are
just a few small guidelines you need to follow.
## Contributor License Agreement
Contributions to this project must be accompanied by a Contributor License
Agreement (CLA). You (or your employer) retain the copyright to your
contribution; this simply gives us permission to use and redistribute your
contributions as part of the project. Head over to
<https://cla.developers.google.com/> to see your current agreements on file or
to sign a new one.
You generally only need to submit a CLA once, so if you've already submitted one
(even if it was for a different project), you probably don't need to do it
again.
## Code reviews
All submissions, including submissions by project members, require review. We
use GitHub pull requests for this purpose. Consult
[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
information on using pull requests.
## Community Guidelines
This project follows
[Google's Open Source Community Guidelines](https://opensource.google/conduct/).
## Coding style
This project follows
[TensorFlow's style](https://www.tensorflow.org/community/contribute/code_style).
================================================
FILE: LICENSE
================================================
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
================================================
FILE: README.md
================================================
# latexify
[](https://pypi.org/project/latexify-py/)
[](https://pypi.org/project/latexify-py/)
[](https://github.com/google/latexify_py/blob/main/LICENSE)
[](https://pepy.tech/project/latexify-py)
[](https://github.com/psf/black)
[](https://pycqa.github.io/isort/)
`latexify` is a Python package to compile a fragment of Python source code to a
corresponding $\LaTeX$ expression:

`latexify` provides the following functionalities:
* Libraries to compile Python source code or AST to $\LaTeX$.
* IPython classes to pretty-print compiled functions.
## FAQs
1. *Which Python versions are supported?*
Syntax from **Python 3.9 to 3.13** is officially supported, or will be supported.
2. *Which technique is used?*
`latexify` is implemented as a rule-based system on the official `ast` package.
3. *Are "AI" techniques adopted?*
`latexify` is based on traditional parsing techniques.
If "AI" refers to techniques around machine learning, the answer is no.
## Getting started
See the
[example notebook](https://github.com/google/latexify_py/blob/main/examples/latexify_examples.ipynb),
which provides several
use cases of this library.
You can also try the above notebook on
[Google Colaboratory](https://colab.research.google.com/github/google/latexify_py/blob/main/examples/latexify_examples.ipynb).
See also the official
[documentation](https://github.com/google/latexify_py/blob/main/docs/index.md)
for more details.
## How to Contribute
To contribute to this project, please refer to
[CONTRIBUTING.md](https://github.com/google/latexify_py/blob/main/CONTRIBUTING.md).
## Disclaimer
This software is currently hosted on <https://github.com/google>, but not officially
supported by Google.
If you have any issues and/or questions about this software, please visit the
[issue tracker](https://github.com/google/latexify_py/issues)
or contact the [main maintainer](https://github.com/odashi).
## License
This software adopts the
[Apache License 2.0](https://github.com/google/latexify_py/blob/main/LICENSE).
================================================
FILE: checks.sh
================================================
#!/bin/bash
# Run the full local check suite: unit tests, formatting, lint, import order
# and type checks. These are the same tools the CI workflow runs.
#
# -e: exit on the first failing command.
# -o pipefail: a pipeline fails if any command in it fails.
# -u: treat references to unset variables as errors.
# -x: echo each command before running it.
set -eoux pipefail
# Unit and integration tests under src/, verbose output.
python -m pytest src -vv
# Formatting check only; does not rewrite files.
python -m black --check src
# flake8 via pyproject-flake8, which reads its config from pyproject.toml.
python -m pflake8 src
# Import-order check only; does not rewrite files.
python -m isort --check src
# Static type checking.
python -m mypy src
================================================
FILE: docs/getting_started.md
================================================
# Getting started
This document describes how to use `latexify` with your Python code.
## Installation
`latexify` depends only on Python libraries at this point.
You can simply install `latexify` via `pip`:
```shell
$ pip install latexify-py
```
Note that you have to install `latexify-py` rather than `latexify`.
## Using `latexify` in Jupyter
The `latexify.function` decorator wraps your functions to pretty-print them as
corresponding LaTeX formulas.
Jupyter recognizes this wrapper and tries to print LaTeX instead of the original function.
The following snippet:
```python
@latexify.function
def solve(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)
solve
```
will print the following formula to the output:
$$ \mathrm{solve}(a, b, c) = \frac{-b + \sqrt{b^2 - 4ac}}{2a} $$
Invoking a wrapped function works transparently, just like calling the original function.
```python
solve(1, 2, 1)
```
```
-1.0
```
Applying `str` to the wrapped function returns the underlying LaTeX source.
```python
print(solve)
```
```
\mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}
```
`latexify.expression` works similarly to `latexify.function`,
but it prints the function without its signature:
```python
@latexify.expression
def solve(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)
solve
```
$$ \frac{-b + \sqrt{b^2 - 4ac}}{2a} $$
## Obtaining LaTeX expression directly
You can also use `latexify.get_latex`, which takes a function and directly returns the
LaTeX expression corresponding to the given function.
The same parameters as `latexify.function` can also be applied to
`latexify.get_latex`.
```python
def solve(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)
latexify.get_latex(solve)
```
```
'\\mathrm{solve}(a, b, c) = \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a}'
```
================================================
FILE: docs/index.md
================================================
# `latexify` documentation
## Index
* [Getting started](getting_started.md)
* [Parameters](parameters.md)
## External resources
* [Examples on Google Colaboratory](https://colab.research.google.com/drive/1MuiawKpVIZ12MWwyYuzZHmbKThdM5wNJ?usp=sharing)
================================================
FILE: docs/parameters.md
================================================
# `latexify` parameters
This document describes the list of parameters to control the behavior of `latexify`.
## `identifiers: dict[str, str]`
Key-value pair of identifiers to replace.
```python
identifiers = {
"my_function": "f",
"my_inner_function": "g",
"my_argument": "x",
}
@latexify.function(identifiers=identifiers)
def my_function(my_argument):
return my_inner_function(my_argument)
my_function
```
$$f(x) = \mathrm{g}\left(x\right)$$
## `reduce_assignments: bool`
Whether to compose all variables defined before the `return` statement.
The current version of `latexify` recognizes only the assignment statements.
Analyzing functions with other control flows may raise errors.
```python
@latexify.function(reduce_assignments=True)
def f(a, b, c):
discriminant = b**2 - 4 * a * c
numerator = -b + math.sqrt(discriminant)
denominator = 2 * a
return numerator / denominator
f
```
$$f(a, b, c) = \frac{-b + \sqrt{b^{2} - 4 a c}}{2 a}$$
## `use_math_symbols: bool`
Whether to automatically convert variables with symbol names into LaTeX symbols or not.
```python
@latexify.function(use_math_symbols=True)
def greek(alpha, beta, gamma, Omega):
return alpha * beta + math.gamma(gamma) + Omega
greek
```
$$\mathrm{greek}({\alpha}, {\beta}, {\gamma}, {\Omega}) = {\alpha} {\beta} + \Gamma\left({{\gamma}}\right) + {\Omega}$$
## `use_set_symbols: bool`
Whether to use binary operators for set operations or not.
```python
@latexify.function(use_set_symbols=True)
def f(x, y):
return x & y, x | y, x - y, x ^ y, x < y, x <= y, x > y, x >= y
f
```
$$f(x, y) = \left( x \cap y\space,\space x \cup y\space,\space x \setminus y\space,\space x \mathbin{\triangle} y\space,\space {x \subset y}\space,\space {x \subseteq y}\space,\space {x \supset y}\space,\space {x \supseteq y}\right)$$
## `use_signature: bool`
Whether to output the function signature or not.
The default value of this flag depends on the frontend function.
`True` is used in `latexify.function`, while `False` is used in `latexify.expression`.
```python
@latexify.function(use_signature=False)
def f(a, b, c):
return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)
f
```
$$\frac{-b + \sqrt{b^{2} - 4 a c}}{2 a}$$
================================================
FILE: examples/latexify_examples.ipynb
================================================
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "W5mNJI3Bnl6n"
},
"source": [
"# `latexify` examples\n",
"\n",
"This notebook provides several examples to use `latexify`.\n",
"\n",
"See also the\n",
"[official documentation](https://github.com/google/latexify_py/blob/main/docs/index.md)\n",
"for more details.\n",
"\n",
"If you have any questions, please ask them in the\n",
"[issue tracker](https://github.com/google/latexify_py/issues)."
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "fWCVgcRHoLd8"
},
"source": [
"## Install `latexify`"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "4IPGyu2dFH6T",
"outputId": "471cab8d-3069-4a27-f3ff-67ba177ec58d"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Collecting latexify-py\n",
" Downloading latexify_py-0.4.2-py3-none-any.whl (38 kB)\n",
"Collecting dill>=0.3.2 (from latexify-py)\n",
" Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25hInstalling collected packages: dill, latexify-py\n",
"Successfully installed dill-0.3.7 latexify-py-0.4.2\n"
]
}
],
"source": [
"# Restart the runtime before running the examples below.\n",
"%pip install latexify-py\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "-Mzq4_dNoSmc"
},
"source": [
"## Import `latexify` into your code"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "hViDMhyMFNCO",
"outputId": "b46edb25-5952-4cff-da1e-d65e7e3caad0"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'0.4.2'"
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import math # Optional\n",
"import numpy as np # Optional\n",
"import latexify\n",
"\n",
"latexify.__version__\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "4QJ6I2s7odX1"
},
"source": [
"## Examples"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "NvbEYSwXFaeE",
"outputId": "5d0ca2a4-a285-4053-9cc4-3776746443be"
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"-1.0\n",
"\\mathrm{solve}(a, b, c) = \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a}\n"
]
}
],
"source": [
"@latexify.function\n",
"def solve(a, b, c):\n",
" return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)\n",
"\n",
"print(solve(1, 4, 3)) # Invoking the function works as expected.\n",
"print(solve) # Printing the function shows the underlying LaTeX source.\n",
"solve # Displays the expression.\n",
"\n",
"# Writes the underlying LaTeX source into a file.\n",
"with open(\"compiled.tex\", \"w\") as fp:\n",
" print(solve, file=fp)\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 56
},
"id": "wS7BhtPgjSak",
"outputId": "76a8547c-e6b5-458d-aeb2-f9df2f35f7c7"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a} $$"
],
"text/plain": [
""
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# latexify.expression works similarly, but does not output the signature.\n",
"@latexify.expression\n",
"def solve(a, b, c):\n",
" return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)\n",
"\n",
"solve\n"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 35
},
"id": "G73dnoqqjg4A",
"outputId": "b9f53cf8-4a34-452c-8d9b-946ddd0998df"
},
"outputs": [
{
"data": {
"application/vnd.google.colaboratory.intrinsic+json": {
"type": "string"
},
"text/plain": [
"'\\\\mathrm{solve}(a, b, c) = \\\\frac{-b + \\\\sqrt{ b^{2} - 4 a c }}{2 a}'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# latexify.get_latex obtains the underlying LaTeX expression directly.\n",
"def solve(a, b, c):\n",
" return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)\n",
"\n",
"latexify.get_latex(solve)\n"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 58
},
"id": "8bYSWIngGF8E",
"outputId": "669e070d-2718-49cb-a2fe-0defe0286b27"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle \\mathrm{sinc}(x) = \\left\\{ \\begin{array}{ll} 1, & \\mathrm{if} \\ x = 0 \\\\ \\frac{\\sin x}{x}, & \\mathrm{otherwise} \\end{array} \\right. $$"
],
"text/plain": [
""
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"@latexify.function\n",
"def sinc(x):\n",
" if x == 0:\n",
" return 1\n",
" else:\n",
" return math.sin(x) / x\n",
"\n",
"sinc\n"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 78
},
"id": "h1i4BjdgHjxl",
"outputId": "e448ff37-4753-4090-b2b1-1ef21b279b34"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle \\mathrm{fib}(x) = \\left\\{ \\begin{array}{ll} 0, & \\mathrm{if} \\ x = 0 \\\\ 1, & \\mathrm{if} \\ x = 1 \\\\ \\mathrm{fib} \\mathopen{}\\left( x - 1 \\mathclose{}\\right) + \\mathrm{fib} \\mathopen{}\\left( x - 2 \\mathclose{}\\right), & \\mathrm{otherwise} \\end{array} \\right. $$"
],
"text/plain": [
""
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Elif or nested else-if are unrolled.\n",
"@latexify.function\n",
"def fib(x):\n",
" if x == 0:\n",
" return 0\n",
" elif x == 1:\n",
" return 1\n",
" else:\n",
" return fib(x-1) + fib(x-2)\n",
"\n",
"fib\n"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 39
},
"id": "-JhJMAXM7j-X",
"outputId": "a47dcd59-2ff9-4aa1-935d-7c789b39057e"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle \\mathrm{greek}(\\alpha, \\beta, \\gamma, \\Omega) = \\alpha \\beta + \\Gamma \\mathopen{}\\left( \\gamma \\mathclose{}\\right) + \\Omega $$"
],
"text/plain": [
""
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Some math symbols are converted automatically.\n",
"@latexify.function(use_math_symbols=True)\n",
"def greek(alpha, beta, gamma, Omega):\n",
" return alpha * beta + math.gamma(gamma) + Omega\n",
"\n",
"greek\n"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 39
},
"id": "ySyNPS0y4tzu",
"outputId": "2d95b5ce-a9b8-42b1-eb55-dc8bd0097d69"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle f(x) = g \\mathopen{}\\left( x \\mathclose{}\\right) $$"
],
"text/plain": [
""
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Function names, arguments, variables can be replaced.\n",
"identifiers = {\n",
" \"my_function\": \"f\",\n",
" \"my_inner_function\": \"g\",\n",
" \"my_argument\": \"x\",\n",
"}\n",
"\n",
"@latexify.function(identifiers=identifiers)\n",
"def my_function(my_argument):\n",
" return my_inner_function(my_argument)\n",
"\n",
"my_function\n"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 56
},
"id": "TyacQaDM4Ei7",
"outputId": "8e971bbd-2c74-45d2-d0fa-7f46569b10a6"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle f(a, b, c) = \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a} $$"
],
"text/plain": [
""
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Assignments can be reduced into one expression.\n",
"@latexify.function(reduce_assignments=True)\n",
"def f(a, b, c):\n",
" discriminant = b**2 - 4 * a * c\n",
" numerator = -b + math.sqrt(discriminant)\n",
" denominator = 2 * a\n",
" return numerator / denominator\n",
"\n",
"f\n"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 78
},
"id": "oD8MFS2WE-2U",
"outputId": "f9fad1bd-b7eb-41cc-8743-ec0d80cca8bc"
},
"outputs": [
{
"data": {
"text/latex": [
"$$ \\displaystyle \\mathrm{transform}(x, y, a, b, \\theta, s, t) = \\begin{bmatrix} 1 & 0 & s \\\\ 0 & 1 & t \\\\ 0 & 0 & 1 \\end{bmatrix} \\cdot \\begin{bmatrix} \\cos \\theta & -\\sin \\theta & 0 \\\\ \\sin \\theta & \\cos \\theta & 0 \\\\ 0 & 0 & 1 \\end{bmatrix} \\cdot \\begin{bmatrix} a & 0 & 0 \\\\ 0 & b & 0 \\\\ 0 & 0 & 1 \\end{bmatrix} \\cdot \\begin{bmatrix} x \\\\ y \\\\ 1 \\end{bmatrix} $$"
],
"text/plain": [
""
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Matrix support.\n",
"@latexify.function(reduce_assignments=True, use_math_symbols=True)\n",
"def transform(x, y, a, b, theta, s, t):\n",
" cos_t = math.cos(theta)\n",
" sin_t = math.sin(theta)\n",
" scale = np.array([[a, 0, 0], [0, b, 0], [0, 0, 1]])\n",
" rotate = np.array([[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]])\n",
" move = np.array([[1, 0, s], [0, 1, t], [0, 0, 1]])\n",
" return move @ rotate @ scale @ np.array([[x], [y], [1]])\n",
"\n",
"transform\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 241
},
"id": "81OlPVWyGfWN",
"outputId": "48660400-a812-41e2-91ea-23e49ea20c7f"
},
"outputs": [
{
"data": {
"text/latex": [
"$ \\begin{array}{l} \\mathbf{function} \\ \\mathrm{fib}(x) \\\\ \\hspace{1em} \\mathbf{if} \\ x = 0 \\\\ \\hspace{2em} \\mathbf{return} \\ 0 \\\\ \\hspace{1em} \\mathbf{else} \\\\ \\hspace{2em} \\mathbf{if} \\ x = 1 \\\\ \\hspace{3em} \\mathbf{return} \\ 1 \\\\ \\hspace{2em} \\mathbf{else} \\\\ \\hspace{3em} \\mathbf{return} \\ \\mathrm{fib} \\mathopen{}\\left( x - 1 \\mathclose{}\\right) + \\mathrm{fib} \\mathopen{}\\left( x - 2 \\mathclose{}\\right) \\\\ \\hspace{2em} \\mathbf{end \\ if} \\\\ \\hspace{1em} \\mathbf{end \\ if} \\\\ \\mathbf{end \\ function} \\end{array} $"
],
"text/plain": [
""
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# latexify.algorithmic generates an algorithmic environment instead of an equation.\n",
"@latexify.algorithmic\n",
"def fib(x):\n",
" if x == 0:\n",
" return 0\n",
" elif x == 1:\n",
" return 1\n",
" else:\n",
" return fib(x-1) + fib(x-2)\n",
"\n",
"fib\n"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 261
},
"id": "kbw_1txkGfnX",
"outputId": "fdc58207-1c06-4d88-e249-b0b011bd98c0"
},
"outputs": [
{
"data": {
"text/latex": [
"$ \\begin{array}{l} \\mathbf{function} \\ \\mathrm{collatz}(x) \\\\ \\hspace{1em} n \\gets 0 \\\\ \\hspace{1em} \\mathbf{while} \\ x > 1 \\\\ \\hspace{2em} n \\gets n + 1 \\\\ \\hspace{2em} \\mathbf{if} \\ x \\mathbin{\\%} 2 = 0 \\\\ \\hspace{3em} x \\gets \\left\\lfloor\\frac{x}{2}\\right\\rfloor \\\\ \\hspace{2em} \\mathbf{else} \\\\ \\hspace{3em} x \\gets 3 x + 1 \\\\ \\hspace{2em} \\mathbf{end \\ if} \\\\ \\hspace{1em} \\mathbf{end \\ while} \\\\ \\hspace{1em} \\mathbf{return} \\ n \\\\ \\mathbf{end \\ function} \\end{array} $"
],
"text/plain": [
""
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Another example: latexify.algorithmic supports usual control flows.\n",
"@latexify.algorithmic\n",
"def collatz(x):\n",
" n = 0\n",
" while x > 1:\n",
" n = n + 1\n",
" if x % 2 == 0:\n",
" x = x // 2\n",
" else:\n",
" x = 3 * x + 1\n",
" return n\n",
"\n",
"collatz\n"
]
}
],
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = [
"hatchling",
]
build-backend = "hatchling.build"
[project]
name = "latexify-py"
description = "Generates LaTeX math description from Python functions."
readme = "README.md"
requires-python = ">=3.9, <3.14"
license = {text = "Apache Software License 2.0"}
authors = [
{name = "Yusuke Oda", email = "odashi@inspiredco.ai"}
]
keywords = [
"equation",
"latex",
"math",
"mathematics",
"tex",
]
classifiers = [
"Framework :: IPython",
"Framework :: Jupyter",
"License :: OSI Approved :: Apache Software License",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Topic :: Scientific/Engineering :: Mathematics",
"Topic :: Software Development :: Code Generators",
"Topic :: Text Processing :: Markup :: LaTeX",
]
dependencies = [
"dill>=0.3.2",
]
dynamic = [
"version"
]
[project.optional-dependencies]
dev = [
"build>=0.8",
"black>=24.3",
"flake8>=6.0",
"isort>=5.10",
"mypy>=1.9",
"notebook>=6.5.1",
"pyproject-flake8>=6.0",
"pytest>=7.1",
"twine>=4.0",
]
mypy = [
"mypy>=1.9",
"pytest>=7.1",
]
[project.urls]
Homepage = "https://github.com/google/latexify_py"
"Bug Tracker" = "https://github.com/google/latexify_py/issues"
[tool.hatch.build]
include = [
"*.py",
]
exclude = [
"*_test.py",
]
only-packages = true
[tool.hatch.build.targets.wheel]
packages = ["src/latexify"]
[tool.hatch.version]
path = "src/latexify/_version.py"
[tool.flake8]
max-line-length = 88
extend-ignore = "E203"
[tool.isort]
profile = "black"
================================================
FILE: src/integration_tests/__init__.py
================================================
"""Package integration_tests."""
import pytest
pytest.register_assert_rewrite("integration_tests.utils")
================================================
FILE: src/integration_tests/algorithmic_style_test.py
================================================
"""End-to-end test cases of algorithmic style."""
from __future__ import annotations
import textwrap
from integration_tests import integration_utils
def test_factorial() -> None:
    """A recursive factorial is rendered as an if/else algorithm."""

    def fact(n):
        if n == 0:
            return 1
        else:
            return n * fact(n - 1)

    algorithmic_latex = textwrap.dedent(
        r"""
        \begin{algorithmic}
            \Function{fact}{$n$}
                \If{$n = 0$}
                    \State \Return $1$
                \Else
                    \State \Return $n \cdot \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$
                \EndIf
            \EndFunction
        \end{algorithmic}
        """  # noqa: E501
    ).strip()

    # IPython cannot render the algorithmic package, so an array-based layout
    # with explicit \hspace indentation is expected there instead.
    ipython_expected = (
        r"\begin{array}{l}"
        r" \mathbf{function} \ \mathrm{fact}(n) \\"
        r" \hspace{1em} \mathbf{if} \ n = 0 \\"
        r" \hspace{2em} \mathbf{return} \ 1 \\"
        r" \hspace{1em} \mathbf{else} \\"
        r" \hspace{2em}"
        r" \mathbf{return} \ n \cdot"
        r" \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right) \\"
        r" \hspace{1em} \mathbf{end \ if} \\"
        r" \mathbf{end \ function}"
        r" \end{array}"
    )
    integration_utils.check_algorithm(
        fn=fact, latex=algorithmic_latex, ipython_latex=ipython_expected
    )
def test_collatz() -> None:
    """A while loop containing an if/else is rendered as an algorithm."""

    def collatz(n):
        iterations = 0
        while n > 1:
            if n % 2 == 0:
                n = n // 2
            else:
                n = 3 * n + 1
            iterations = iterations + 1
        return iterations

    algorithmic_latex = textwrap.dedent(
        r"""
        \begin{algorithmic}
            \Function{collatz}{$n$}
                \State $\mathrm{iterations} \gets 0$
                \While{$n > 1$}
                    \If{$n \mathbin{\%} 2 = 0$}
                        \State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$
                    \Else
                        \State $n \gets 3 n + 1$
                    \EndIf
                    \State $\mathrm{iterations} \gets \mathrm{iterations} + 1$
                \EndWhile
                \State \Return $\mathrm{iterations}$
            \EndFunction
        \end{algorithmic}
        """
    ).strip()

    # Array-based rendering used for IPython display.
    ipython_expected = (
        r"\begin{array}{l}"
        r" \mathbf{function} \ \mathrm{collatz}(n) \\"
        r" \hspace{1em} \mathrm{iterations} \gets 0 \\"
        r" \hspace{1em} \mathbf{while} \ n > 1 \\"
        r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\"
        r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\"
        r" \hspace{2em} \mathbf{else} \\"
        r" \hspace{3em} n \gets 3 n + 1 \\"
        r" \hspace{2em} \mathbf{end \ if} \\"
        r" \hspace{2em}"
        r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\"
        r" \hspace{1em} \mathbf{end \ while} \\"
        r" \hspace{1em} \mathbf{return} \ \mathrm{iterations} \\"
        r" \mathbf{end \ function}"
        r" \end{array}"
    )
    integration_utils.check_algorithm(
        fn=collatz, latex=algorithmic_latex, ipython_latex=ipython_expected
    )
================================================
FILE: src/integration_tests/function_expansion_test.py
================================================
"""End-to-end test cases of function expansion."""
from __future__ import annotations
import math
from integration_tests import integration_utils
def test_atan2() -> None:
    """atan2(y, x) expands into arctan(y / x)."""

    def solve(x, y):
        return math.atan2(y, x)

    expected = (
        r"\mathrm{solve}(x, y) ="
        r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)"
    )
    integration_utils.check_function(solve, expected, expand_functions={"atan2"})
def test_atan2_nested() -> None:
    """Arguments of an expanded atan2 are expanded as well."""

    def solve(x, y):
        return math.atan2(math.exp(y), math.exp(x))

    expected = (
        r"\mathrm{solve}(x, y) ="
        r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)"
    )
    integration_utils.check_function(solve, expected, expand_functions={"atan2", "exp"})
def test_exp() -> None:
    """exp(x) expands into a power of e."""

    def solve(x):
        return math.exp(x)

    expected = r"\mathrm{solve}(x) = e^{x}"
    integration_utils.check_function(solve, expected, expand_functions={"exp"})
def test_exp_nested() -> None:
    """Nested exp calls become nested powers of e."""

    def solve(x):
        return math.exp(math.exp(x))

    expected = r"\mathrm{solve}(x) = e^{e^{x}}"
    integration_utils.check_function(solve, expected, expand_functions={"exp"})
def test_exp2() -> None:
    """exp2(x) expands into a power of 2."""

    def solve(x):
        return math.exp2(x)

    expected = r"\mathrm{solve}(x) = 2^{x}"
    integration_utils.check_function(solve, expected, expand_functions={"exp2"})
def test_exp2_nested() -> None:
    """Nested exp2 calls become nested powers of 2."""

    def solve(x):
        return math.exp2(math.exp2(x))

    expected = r"\mathrm{solve}(x) = 2^{2^{x}}"
    integration_utils.check_function(solve, expected, expand_functions={"exp2"})
def test_expm1() -> None:
    """expm1(x) expands into exp(x) - 1 (exp itself is not expanded)."""

    def solve(x):
        return math.expm1(x)

    expected = r"\mathrm{solve}(x) = \exp x - 1"
    integration_utils.check_function(solve, expected, expand_functions={"expm1"})
def test_expm1_nested() -> None:
    """expm1, exp and pow expansions compose with each other."""

    def solve(x, y, z):
        return math.expm1(math.pow(y, z))

    expected = r"\mathrm{solve}(x, y, z) = e^{y^{z}} - 1"
    integration_utils.check_function(
        solve, expected, expand_functions={"expm1", "exp", "pow"}
    )
def test_hypot_without_attribute() -> None:
    """hypot is also expanded when called without the `math.` prefix."""
    from math import hypot

    def solve(x, y, z):
        return hypot(x, y, z)

    expected = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }"
    integration_utils.check_function(solve, expected, expand_functions={"hypot"})
def test_hypot() -> None:
    """hypot expands into the square root of a sum of squares."""

    def solve(x, y, z):
        return math.hypot(x, y, z)

    expected = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }"
    integration_utils.check_function(solve, expected, expand_functions={"hypot"})
def test_hypot_nested() -> None:
    """A hypot nested inside hypot is expanded in place."""

    def solve(a, b, x, y):
        return math.hypot(math.hypot(a, b), x, y)

    expected = (
        r"\mathrm{solve}(a, b, x, y) ="
        r" \sqrt{ \sqrt{ a^{2} + b^{2} }^{2} + x^{2} + y^{2} }"
    )
    integration_utils.check_function(solve, expected, expand_functions={"hypot"})
def test_log1p() -> None:
    """log1p(x) expands into log(1 + x)."""

    def solve(x):
        return math.log1p(x)

    expected = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + x \mathclose{}\right)"
    integration_utils.check_function(solve, expected, expand_functions={"log1p"})
def test_log1p_nested() -> None:
    """log1p and exp expansions compose."""

    def solve(x):
        return math.log1p(math.exp(x))

    expected = (
        r"\mathrm{solve}(x) ="
        r" \log \mathopen{}\left( 1 + e^{x} \mathclose{}\right)"
    )
    integration_utils.check_function(solve, expected, expand_functions={"log1p", "exp"})
def test_pow_nested() -> None:
    """A power used as a base is parenthesized; one used as exponent is not."""

    def solve(w, x, y, z):
        return math.pow(math.pow(w, x), math.pow(y, z))

    expected = (
        r"\mathrm{solve}(w, x, y, z) ="
        r" \mathopen{}\left( w^{x} \mathclose{}\right)^{y^{z}}"
    )
    integration_utils.check_function(solve, expected, expand_functions={"pow"})
def test_pow() -> None:
    """pow(x, y) expands into a superscript."""

    def solve(x, y):
        return math.pow(x, y)

    expected = r"\mathrm{solve}(x, y) = x^{y}"
    integration_utils.check_function(solve, expected, expand_functions={"pow"})
================================================
FILE: src/integration_tests/integration_utils.py
================================================
"""Utilities for integration tests."""
from __future__ import annotations
from typing import Any, Callable
from latexify import frontend
def check_function(
    fn: Callable[..., Any],
    latex: str,
    **kwargs,
) -> None:
    """Asserts that every way of applying `frontend.function` yields `latex`.

    Both `str()` of the result and its IPython `_repr_latex_()` rendering
    (`latex` wrapped in display-math delimiters) are compared.

    Args:
        fn: Function to check.
        latex: Expected LaTeX form of `fn`.
        **kwargs: Arguments passed to `frontend.function`.
    """
    expected_repr = rf"$$ \displaystyle {latex} $$"

    def verify(latexified) -> None:
        assert str(latexified) == latex
        assert latexified._repr_latex_() == expected_repr

    # Bare decorator form, only expressible when there are no options:
    #     @function
    #     def fn(...): ...
    if not kwargs:
        verify(frontend.function(fn))

    # Decorator-factory form:
    #     @function(**kwargs)
    #     def fn(...): ...
    verify(frontend.function(**kwargs)(fn))

    # Direct call form:
    #     latexified = function(fn, **kwargs)
    verify(frontend.function(fn, **kwargs))
def check_algorithm(
    fn: Callable[..., Any],
    latex: str,
    ipython_latex: str,
    **kwargs,
) -> None:
    """Helper to check if the obtained algorithm has the expected LaTeX forms.

    Every supported way of applying `frontend.algorithmic` is exercised, and
    both `str()` of the result and its `_repr_latex_()` rendering are checked.

    Args:
        fn: Function to check.
        latex: LaTeX form of `fn`.
        ipython_latex: IPython LaTeX form of `fn`, without the surrounding `$`.
        **kwargs: Arguments passed to `frontend.algorithmic`.
    """
    # Checks the syntax:
    #     @algorithmic
    #     def fn(...):
    #         ...
    # (Only possible when no options are given.)
    if not kwargs:
        latexified = frontend.algorithmic(fn)
        assert str(latexified) == latex
        assert latexified._repr_latex_() == f"$ {ipython_latex} $"

    # Checks the syntax:
    #     @algorithmic(**kwargs)
    #     def fn(...):
    #         ...
    latexified = frontend.algorithmic(**kwargs)(fn)
    assert str(latexified) == latex
    assert latexified._repr_latex_() == f"$ {ipython_latex} $"

    # Checks the syntax:
    #     def fn(...):
    #         ...
    #     latexified = algorithmic(fn, **kwargs)
    latexified = frontend.algorithmic(fn, **kwargs)
    assert str(latexified) == latex
    assert latexified._repr_latex_() == f"$ {ipython_latex} $"
================================================
FILE: src/integration_tests/regression_test.py
================================================
"""End-to-end test cases of function."""
from __future__ import annotations
import math
from integration_tests import integration_utils
def test_quadratic_solution() -> None:
    """The quadratic formula renders with a fraction and a square root."""

    def solve(a, b, c):
        return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a)

    expected = (
        r"\mathrm{solve}(a, b, c) ="
        r" \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}"
    )
    integration_utils.check_function(solve, expected)
def test_sinc() -> None:
    """An if/else with returns in both branches becomes a piecewise definition."""

    def sinc(x):
        if x == 0:
            return 1
        else:
            return math.sin(x) / x

    expected = (
        r"\mathrm{sinc}(x) ="
        r" \left\{ \begin{array}{ll}"
        r" 1, & \mathrm{if} \ x = 0 \\"
        r" \frac{\sin x}{x}, & \mathrm{otherwise}"
        r" \end{array} \right."
    )
    integration_utils.check_function(sinc, expected)
def test_x_times_beta() -> None:
    """`beta` becomes a Greek letter only when use_math_symbols is enabled."""

    def xtimesbeta(x, beta):
        return x * beta

    as_text = r"\mathrm{xtimesbeta}(x, \mathrm{beta}) = x \cdot \mathrm{beta}"
    as_symbol = r"\mathrm{xtimesbeta}(x, \beta) = x \beta"

    # The default and an explicit False behave the same.
    integration_utils.check_function(xtimesbeta, as_text)
    integration_utils.check_function(xtimesbeta, as_text, use_math_symbols=False)
    integration_utils.check_function(xtimesbeta, as_symbol, use_math_symbols=True)
def test_sum_with_limit_1arg() -> None:
    """sum over range(n) becomes a sum from 0 to n - 1."""

    def sum_with_limit(n):
        return sum(i**2 for i in range(n))

    expected = (
        r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n - 1}"
        r" \mathopen{}\left({i^{2}}\mathclose{}\right)"
    )
    integration_utils.check_function(sum_with_limit, expected)
def test_sum_with_limit_2args() -> None:
    """sum over range(a, n) becomes a sum from a to n - 1."""

    def sum_with_limit(a, n):
        return sum(i**2 for i in range(a, n))

    expected = (
        r"\mathrm{sum\_with\_limit}(a, n) = \sum_{i = a}^{n - 1}"
        r" \mathopen{}\left({i^{2}}\mathclose{}\right)"
    )
    integration_utils.check_function(sum_with_limit, expected)
def test_sum_with_reducible_limit() -> None:
    """range(n + 1) yields an upper limit of n, not n + 1 - 1."""

    def sum_with_limit(n):
        return sum(i for i in range(n + 1))

    expected = (
        r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n}"
        r" \mathopen{}\left({i}\mathclose{}\right)"
    )
    integration_utils.check_function(sum_with_limit, expected)
def test_sum_with_irreducible_limit() -> None:
    """A non-additive stop keeps an explicit `- 1` in the upper limit."""

    def sum_with_limit(n):
        return sum(i for i in range(n * 3))

    expected = (
        r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n \cdot 3 - 1}"
        r" \mathopen{}\left({i}\mathclose{}\right)"
    )
    integration_utils.check_function(sum_with_limit, expected)
def test_prod_with_limit_1arg() -> None:
    """math.prod over range(n) becomes a product from 0 to n - 1."""

    def prod_with_limit(n):
        return math.prod(i**2 for i in range(n))

    expected = (
        r"\mathrm{prod\_with\_limit}(n) ="
        r" \prod_{i = 0}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)"
    )
    integration_utils.check_function(prod_with_limit, expected)
def test_prod_with_limit_2args() -> None:
    """math.prod over range(a, n) becomes a product from a to n - 1."""

    def prod_with_limit(a, n):
        return math.prod(i**2 for i in range(a, n))

    expected = (
        r"\mathrm{prod\_with\_limit}(a, n) ="
        r" \prod_{i = a}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)"
    )
    integration_utils.check_function(prod_with_limit, expected)
def test_prod_with_reducible_limits() -> None:
    """range(n - 1) folds the constants into an upper limit of n - 2."""

    def prod_with_limit(n):
        return math.prod(i for i in range(n - 1))

    expected = (
        r"\mathrm{prod\_with\_limit}(n) ="
        r" \prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)"
    )
    integration_utils.check_function(prod_with_limit, expected)
def test_prod_with_irreducible_limit() -> None:
    """A multiplicative stop keeps an explicit `- 1` in the upper limit."""

    def prod_with_limit(n):
        return math.prod(i for i in range(n * 3))

    expected = (
        r"\mathrm{prod\_with\_limit}(n) ="
        r" \prod_{i = 0}^{n \cdot 3 - 1} \mathopen{}\left({i}\mathclose{}\right)"
    )
    integration_utils.check_function(prod_with_limit, expected)
def test_nested_function() -> None:
    """A function defined inside the test body can be latexified."""

    def nested(x):
        return 3 * x

    expected = r"\mathrm{nested}(x) = 3 x"
    integration_utils.check_function(nested, expected)
def test_double_nested_function() -> None:
    """A returned closure is latexified; its free variable stays symbolic."""

    def nested(x):
        def inner(y):
            return x * y

        return inner

    closure = nested(3)
    integration_utils.check_function(closure, r"\mathrm{inner}(y) = x y")
def test_reduce_assignments() -> None:
    """reduce_assignments substitutes an intermediate variable into the return."""

    def f(x):
        a = x + x
        return 3 * a

    separate_lines = r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}"
    substituted = r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)"

    integration_utils.check_function(f, separate_lines)
    integration_utils.check_function(f, substituted, reduce_assignments=True)
def test_reduce_assignments_double() -> None:
    """Chained assignments are substituted transitively when reducing."""

    def f(x):
        a = x**2
        b = a + a
        return 3 * b

    separate_lines = (
        r"\begin{array}{l}"
        r" a = x^{2} \\"
        r" b = a + a \\"
        r" f(x) = 3 b"
        r" \end{array}"
    )
    substituted = r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)"

    # Not reducing is the default, so both of these keep separate lines.
    integration_utils.check_function(f, separate_lines)
    integration_utils.check_function(f, separate_lines, reduce_assignments=False)
    integration_utils.check_function(f, substituted, reduce_assignments=True)
def test_reduce_assignments_with_if() -> None:
    """Reduced assignments are substituted into each branch of an if/else."""

    def sigmoid(x):
        p = 1 / (1 + math.exp(-x))
        n = math.exp(x) / (math.exp(x) + 1)
        if x > 0:
            return p
        else:
            return n

    expected = (
        r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll}"
        r" \frac{1}{1 + \exp \mathopen{}\left( -x \mathclose{}\right)}, &"
        r" \mathrm{if} \ x > 0 \\"
        r" \frac{\exp x}{\exp x + 1}, &"
        r" \mathrm{otherwise}"
        r" \end{array} \right."
    )
    integration_utils.check_function(sigmoid, expected, reduce_assignments=True)
def test_sub_bracket() -> None:
    """Parentheses are kept only where subtraction needs them."""

    def solve(a, b):
        return ((a + b) - b) / (a - b) - (a + b) - (a - b) - (a * b)

    expected = (
        r"\mathrm{solve}(a, b) ="
        r" \frac{a + b - b}{a - b} - \mathopen{}\left("
        r" a + b \mathclose{}\right) - \mathopen{}\left("
        r" a - b \mathclose{}\right) - a b"
    )
    integration_utils.check_function(solve, expected)
def test_docstring_allowed() -> None:
    """A docstring in the target function does not appear in the output."""

    def solve(x):
        """The identity function."""
        return x

    integration_utils.check_function(solve, r"\mathrm{solve}(x) = x")
def test_multiple_constants_allowed() -> None:
    """Bare constant statements in the target function are ignored."""

    def solve(x):
        """The identity function."""
        123
        True
        return x

    integration_utils.check_function(solve, r"\mathrm{solve}(x) = x")
================================================
FILE: src/latexify/__init__.py
================================================
"""Latexify root package."""
try:
from latexify import _version
__version__ = _version.__version__
except Exception:
__version__ = ""
from latexify import frontend, generate_latex
Style = generate_latex.Style
get_latex = generate_latex.get_latex
algorithmic = frontend.algorithmic
expression = frontend.expression
function = frontend.function
================================================
FILE: src/latexify/_version.py
================================================
"""Version specifier.
DON'T TOUCH THIS FILE.
This file is replaced during the release process.
"""
__version__ = "0.0.0a0"
================================================
FILE: src/latexify/analyzers.py
================================================
"""Analyzer functions for specific subtrees."""
from __future__ import annotations
import ast
import dataclasses
import sys
from latexify import ast_utils, exceptions
@dataclasses.dataclass(frozen=True, eq=False)
class RangeInfo:
    """Decomposed arguments of a `range(...)` call.

    Instances are immutable (`frozen=True`) and compare by identity
    (`eq=False`).

    Attributes:
        start: Subtree of the start argument.
        stop: Subtree of the stop argument.
        step: Subtree of the step argument.
        start_int: `start` as an int when it is an int literal, else None.
        stop_int: `stop` as an int when it is an int literal, else None.
        step_int: `step` as an int when it is an int literal, else None.
    """

    # Argument subtrees; these may be shallow copies of the original subtree.
    start: ast.expr
    stop: ast.expr
    step: ast.expr

    # Integer values of the arguments above, when they can be determined.
    start_int: int | None
    stop_int: int | None
    step_int: int | None
def analyze_range(node: ast.Call) -> RangeInfo:
    """Extracts the start/stop/step arguments of a `range(...)` call.

    Omitted arguments are filled with their Python defaults: start=0, step=1.

    Args:
        node: Subtree to be analyzed.

    Returns:
        RangeInfo extracted from `node`.

    Raises:
        LatexifySyntaxError: `node` is not a direct call to `range` with one to
            three positional arguments.
    """
    is_range_call = isinstance(node.func, ast.Name) and node.func.id == "range"
    if not is_range_call or not 1 <= len(node.args) <= 3:
        raise exceptions.LatexifySyntaxError("Unsupported AST for analyze_range.")

    call_args = node.args
    if len(call_args) == 1:
        # range(stop)
        start = ast_utils.make_constant(0)
        (stop,) = call_args
        step = ast_utils.make_constant(1)
    else:
        # range(start, stop) or range(start, stop, step)
        start, stop = call_args[0], call_args[1]
        if len(call_args) == 3:
            step = call_args[2]
        else:
            step = ast_utils.make_constant(1)

    start_value = ast_utils.extract_int_or_none(start)
    stop_value = ast_utils.extract_int_or_none(stop)
    step_value = ast_utils.extract_int_or_none(step)

    return RangeInfo(
        start=start,
        stop=stop,
        step=step,
        start_int=start_value,
        stop_int=stop_value,
        step_int=step_value,
    )
def reduce_stop_parameter(node: ast.expr) -> ast.expr:
    """Adjusts the stop expression of the range.

    `range(..., stop)` iterates up to `stop - 1`, so this function returns an
    expression for `stop - 1`. When `stop` has the form `<expr> + <int>` or
    `<expr> - <int>`, the constant is folded:

    * n + 1 --> n
    * n + 2 --> n + 1
    * n + 0 --> n - 1
    * n - 1 --> n - 2

    Any other expression, including a sum whose right operand is a non-int
    constant, is returned as `node - 1`.

    Args:
        node: The target expression.

    Returns:
        Converted expression.
    """
    if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub))):
        return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1))

    # Treatment for Python 3.7 (ast.Num).
    # NOTE(review): unreachable on the supported versions (requires-python
    # >= 3.9). Drop it together with this module's `sys` import.
    rhs = (
        ast.Constant(value=node.right.n)
        if sys.version_info.minor < 8 and isinstance(node.right, ast.Num)
        else node.right
    )

    # Only int literals are folded. range() accepts ints only, and other
    # constants (str, None, ...) would make the arithmetic below raise TypeError.
    rhs_int = ast_utils.extract_int_or_none(rhs)
    if rhs_int is None:
        return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1))

    shift = 1 if isinstance(node.op, ast.Add) else -1
    if rhs_int == shift:
        return node.left

    folded = rhs_int - shift
    op: ast.operator = node.op
    if folded < 0:
        # Keep the constant non-negative, e.g. n + 0 --> n - 1 (not n + -1).
        op = ast.Sub() if isinstance(op, ast.Add) else ast.Add()
        folded = -folded

    return ast.BinOp(
        left=node.left,
        op=op,
        right=ast_utils.make_constant(value=folded),
    )
================================================
FILE: src/latexify/analyzers_test.py
================================================
"""Tests for latexify.analyzers."""
from __future__ import annotations
import ast
import pytest
from latexify import analyzers, ast_utils, exceptions, test_utils
@pytest.mark.parametrize(
    "code,start,stop,step,start_int,stop_int,step_int",
    [
        (
            "range(x)",
            ast.Constant(value=0),
            ast.Name(id="x", ctx=ast.Load()),
            ast.Constant(value=1),
            0,
            None,
            1,
        ),
        (
            "range(123)",
            ast.Constant(value=0),
            ast.Constant(value=123),
            ast.Constant(value=1),
            0,
            123,
            1,
        ),
        (
            "range(x, y)",
            ast.Name(id="x", ctx=ast.Load()),
            ast.Name(id="y", ctx=ast.Load()),
            ast.Constant(value=1),
            None,
            None,
            1,
        ),
        (
            "range(123, y)",
            ast.Constant(value=123),
            ast.Name(id="y", ctx=ast.Load()),
            ast.Constant(value=1),
            123,
            None,
            1,
        ),
        (
            "range(x, 123)",
            ast.Name(id="x", ctx=ast.Load()),
            ast.Constant(value=123),
            ast.Constant(value=1),
            None,
            123,
            1,
        ),
        (
            "range(x, y, z)",
            ast.Name(id="x", ctx=ast.Load()),
            ast.Name(id="y", ctx=ast.Load()),
            ast.Name(id="z", ctx=ast.Load()),
            None,
            None,
            None,
        ),
        (
            "range(123, y, z)",
            ast.Constant(value=123),
            ast.Name(id="y", ctx=ast.Load()),
            ast.Name(id="z", ctx=ast.Load()),
            123,
            None,
            None,
        ),
        (
            "range(x, 123, z)",
            ast.Name(id="x", ctx=ast.Load()),
            ast.Constant(value=123),
            ast.Name(id="z", ctx=ast.Load()),
            None,
            123,
            None,
        ),
        (
            "range(x, y, 123)",
            ast.Name(id="x", ctx=ast.Load()),
            ast.Name(id="y", ctx=ast.Load()),
            ast.Constant(value=123),
            None,
            None,
            123,
        ),
    ],
)
def test_analyze_range(
    code: str,
    start: ast.expr,
    stop: ast.expr,
    step: ast.expr,
    start_int: int | None,
    stop_int: int | None,
    step_int: int | None,
) -> None:
    """analyze_range extracts each argument subtree and its int value."""
    node = ast_utils.parse_expr(code)
    assert isinstance(node, ast.Call)

    info = analyzers.analyze_range(node)

    test_utils.assert_ast_equal(observed=info.start, expected=start)
    test_utils.assert_ast_equal(observed=info.stop, expected=stop)
    # analyze_range always fills in `step` (defaulting to the constant 1), and
    # every case above supplies an expected step subtree.
    test_utils.assert_ast_equal(observed=info.step, expected=step)

    def check_int(observed: int | None, expected: int | None) -> None:
        if expected is not None:
            assert observed == expected
        else:
            assert observed is None

    check_int(observed=info.start_int, expected=start_int)
    check_int(observed=info.stop_int, expected=stop_int)
    check_int(observed=info.step_int, expected=step_int)
@pytest.mark.parametrize(
    "code",
    [
        # Not a direct call
        "__builtins__.range(x)",
        'getattr(__builtins__, "range")(x)',
        # Unsupported functions
        "f(x)",
        "iter(range(x))",
        # Range with invalid arguments
        "range()",
        "range(x, y, z, w)",
    ],
)
def test_analyze_range_invalid(code: str) -> None:
    """Anything but a direct range() call with 1 to 3 arguments is rejected."""
    call = ast_utils.parse_expr(code)
    assert isinstance(call, ast.Call)

    expected_message = r"^Unsupported AST for analyze_range\.$"
    with pytest.raises(exceptions.LatexifySyntaxError, match=expected_message):
        analyzers.analyze_range(call)
@pytest.mark.parametrize(
    "before,after",
    [
        ("n + 1", "n"),
        ("n + 2", "n + 1"),
        ("n - (-1)", "n - (-1) - 1"),
        ("n - 1", "n - 2"),
        ("1 * 2", "1 * 2 - 1"),
    ],
)
def test_reduce_stop_parameter(before: str, after: str) -> None:
    """reduce_stop_parameter(stop) yields an expression equal to `stop - 1`."""
    reduced = analyzers.reduce_stop_parameter(ast_utils.parse_expr(before))
    expected = ast_utils.parse_expr(after)
    test_utils.assert_ast_equal(reduced, expected)
================================================
FILE: src/latexify/ast_utils.py
================================================
"""Utilities to generate AST nodes."""
from __future__ import annotations
import ast
import sys
from typing import Any
def parse_expr(code: str) -> ast.expr:
    """Parses a single Python expression.

    Args:
        code: Source text of the expression.

    Returns:
        The expression node itself (the `body` of the `ast.Expression` wrapper).
    """
    expression_module = ast.parse(code, mode="eval")
    return expression_module.body
def make_name(id: str) -> ast.Name:
    """Builds a `Name` node that reads the given identifier.

    Args:
        id: Identifier of the variable.

    Returns:
        An `ast.Name` in load context.
    """
    load_ctx = ast.Load()
    return ast.Name(id=id, ctx=load_ctx)
def make_attribute(value: ast.expr, attr: str) -> ast.Attribute:
    """Generates a new Attribute node in load context (`value.attr`).

    Args:
        value: Parent value.
        attr: Attribute name.

    Returns:
        Generated ast.Attribute.
    """
    return ast.Attribute(value=value, attr=attr, ctx=ast.Load())
def make_constant(value: Any) -> ast.expr:
    """Wraps a Python literal value in a `Constant` node.

    Accepted values are None, Ellipsis, and instances of bool, int, float,
    complex, str and bytes.

    Args:
        value: Value of the node.

    Returns:
        Generated ast.Constant or its equivalent.

    Raises:
        ValueError: Unsupported value type.
    """
    literal_types = (bool, int, float, complex, str, bytes)
    is_singleton = value is None or value is ...
    if not (is_singleton or isinstance(value, literal_types)):
        type_name = type(value).__name__
        raise ValueError(f"Unsupported type to generate Constant: {type_name}")
    return ast.Constant(value=value)
def is_constant(node: ast.AST) -> bool:
    """Tells whether `node` is an `ast.Constant`, whatever value it holds.

    Args:
        node: The node to examine.

    Returns:
        True for constant nodes, False for every other node.
    """
    if isinstance(node, ast.Constant):
        return True
    return False
def is_str(node: ast.AST) -> bool:
    """Checks if the node is a str constant.

    Args:
        node: The node to examine.

    Returns:
        True if the node is an `ast.Constant` holding a `str`, False otherwise.
    """
    # Since Python 3.8, string literals parse to ast.Constant. The former
    # ast.Str fallback was unreachable on the supported versions (>= 3.9).
    return isinstance(node, ast.Constant) and isinstance(node.value, str)
def extract_int_or_none(node: ast.expr) -> int | None:
    """Returns the int held by a `Constant` node, if any.

    Args:
        node: ast.Constant or its equivalent representing an int value.

    Returns:
        The int value, or None when `node` is not an int constant. Booleans are
        rejected even though `bool` subclasses `int`.
    """
    if not isinstance(node, ast.Constant):
        return None
    value = node.value
    if isinstance(value, bool) or not isinstance(value, int):
        return None
    return value
def extract_int(node: ast.expr) -> int:
    """Returns the int held by a `Constant` node, or raises.

    Args:
        node: ast.Constant or its equivalent representing an int value.

    Returns:
        Extracted int value.

    Raises:
        ValueError: Not a subtree containing an int value.
    """
    extracted = extract_int_or_none(node)
    if extracted is not None:
        return extracted
    raise ValueError(f"Unsupported node to extract int: {type(node).__name__}")
def extract_function_name_or_none(node: ast.Call) -> str | None:
    """Returns the name of the called function.

    For `f(...)` this is `f`; for `obj.f(...)` it is the attribute name `f`.

    Args:
        node: ast.Call.

    Returns:
        The function name, or None when the callee is neither a plain name nor
        an attribute access.
    """
    callee = node.func
    if isinstance(callee, ast.Name):
        return callee.id
    return callee.attr if isinstance(callee, ast.Attribute) else None
def create_function_def(
    name,
    args,
    body,
    decorator_list,
    returns=None,
    type_comment=None,
    type_params=None,
    lineno=None,
    col_offset=None,
    end_lineno=None,
    end_col_offset=None,
) -> ast.FunctionDef:
    """Creates a FunctionDef node.

    This function generates an `ast.FunctionDef` node. The `type_params` keyword
    argument is forwarded only on Python 3.12 and later, where FunctionDef gained
    that field; on older versions it is dropped.

    Args:
        name: Name of the function.
        args: Arguments of the function.
        body: Body of the function.
        decorator_list: List of decorators.
        returns: Return type of the function.
        type_comment: Type comment of the function.
        type_params: Type parameters of the function (ignored before Python 3.12).
        lineno: Line number of the function definition.
        col_offset: Column offset of the function definition.
        end_lineno: End line number of the function definition.
        end_col_offset: End column offset of the function definition.

    Returns:
        ast.FunctionDef: The generated FunctionDef node.
    """
    fields = {
        "name": name,
        "args": args,
        "body": body,
        "decorator_list": decorator_list,
        "returns": returns,
        "type_comment": type_comment,
        "lineno": lineno,
        "col_offset": col_offset,
        "end_lineno": end_lineno,
        "end_col_offset": end_col_offset,
    }
    # Compare the whole version tuple; checking only `.minor` would misjudge any
    # interpreter whose major version is not 3.
    if sys.version_info >= (3, 12):
        fields["type_params"] = type_params
    return ast.FunctionDef(**fields)  # type: ignore
================================================
FILE: src/latexify/ast_utils_test.py
================================================
"""Tests for latexify.ast_utils."""
from __future__ import annotations
import ast
import sys
from typing import Any
import pytest
from latexify import ast_utils, test_utils
def test_parse_expr() -> None:
    """parse_expr turns a source snippet into the matching expression node."""
    observed = ast_utils.parse_expr("a + b")
    expected = ast.BinOp(
        left=ast_utils.make_name("a"),
        op=ast.Add(),
        right=ast_utils.make_name("b"),
    )
    test_utils.assert_ast_equal(observed, expected)
def test_make_name() -> None:
    """make_name builds a Name node in Load context."""
    observed = ast_utils.make_name("foo")
    expected = ast.Name(id="foo", ctx=ast.Load())
    test_utils.assert_ast_equal(observed, expected)
def test_make_attribute() -> None:
    """make_attribute builds an Attribute node in Load context."""
    observed = ast_utils.make_attribute(ast_utils.make_name("foo"), "bar")
    owner = ast.Name(id="foo", ctx=ast.Load())
    expected = ast.Attribute(owner, attr="bar", ctx=ast.Load())
    test_utils.assert_ast_equal(observed, expected)
@pytest.mark.parametrize(
    "value,expected",
    [
        (literal, ast.Constant(value=literal))
        for literal in (None, False, True, ..., 123, 4.5, 6 + 7j, "foo", b"bar")
    ],
)
def test_make_constant(value: Any, expected: ast.Constant) -> None:
    """make_constant wraps every supported literal in an ast.Constant."""
    generated = ast_utils.make_constant(value)
    test_utils.assert_ast_equal(observed=generated, expected=expected)
def test_make_constant_invalid() -> None:
    """make_constant rejects values that cannot live in an ast.Constant."""
    unsupported = object()
    with pytest.raises(ValueError, match=r"^Unsupported type to generate"):
        ast_utils.make_constant(unsupported)
@pytest.mark.parametrize(
    "value,expected",
    [
        (ast.Constant(value="foo"), True),
        (ast.Expr(value=ast.Constant(value=123)), False),
        (ast.Global(names=["bar"]), False),
    ],
)
def test_is_constant(value: ast.AST, expected: bool) -> None:
    """Only bare Constant nodes count as constants."""
    observed = ast_utils.is_constant(value)
    assert observed is expected
@pytest.mark.parametrize(
    "value,expected",
    [
        (ast.Constant(value=123), False),
        (ast.Constant(value="foo"), True),
        (ast.Expr(value=ast.Constant(value="foo")), False),
        (ast.Global(names=["foo"]), False),
    ],
)
def test_is_str(value: ast.AST, expected: bool) -> None:
    """Only Constant nodes holding a str are reported as strings."""
    observed = ast_utils.is_str(value)
    assert observed is expected
def test_extract_int_or_none() -> None:
    """Int constants round-trip through extract_int_or_none."""
    for number in (-123, 0, 123):
        node = ast_utils.make_constant(number)
        assert ast_utils.extract_int_or_none(node) == number
def test_extract_int_or_none_invalid() -> None:
    """Constants whose value is not an int (bool included) yield None."""
    non_ints = (None, True, ..., 123.0, 4 + 5j, "123", b"123")
    for candidate in non_ints:
        node = ast_utils.make_constant(candidate)
        assert ast_utils.extract_int_or_none(node) is None
def test_extract_int() -> None:
    """Int constants round-trip through extract_int."""
    for number in (-123, 0, 123):
        node = ast_utils.make_constant(number)
        assert ast_utils.extract_int(node) == number
def test_extract_int_invalid() -> None:
    """Constants whose value is not an int (bool included) raise ValueError."""
    non_ints = (None, True, ..., 123.0, 4 + 5j, "123", b"123")
    for candidate in non_ints:
        with pytest.raises(ValueError, match=r"^Unsupported node to extract int"):
            ast_utils.extract_int(ast_utils.make_constant(candidate))
@pytest.mark.parametrize(
    "value,expected",
    [
        # Plain call: hypot()
        (
            ast.Call(func=ast.Name(id="hypot", ctx=ast.Load()), args=[], keywords=[]),
            "hypot",
        ),
        # Attribute call: math.hypot()
        (
            ast.Call(
                func=ast.Attribute(
                    value=ast.Name(id="math", ctx=ast.Load()),
                    attr="hypot",
                    ctx=ast.Load(),
                ),
                args=[],
                keywords=[],
            ),
            "hypot",
        ),
        # Call of a call: foo()() has no extractable name.
        (
            ast.Call(
                func=ast.Call(
                    func=ast.Name(id="foo", ctx=ast.Load()), args=[], keywords=[]
                ),
                args=[],
                keywords=[],
            ),
            None,
        ),
    ],
)
def test_extract_function_name_or_none(value: ast.Call, expected: str | None) -> None:
    """The callee name is recovered from Name and Attribute callees only."""
    observed = ast_utils.extract_function_name_or_none(value)
    assert observed == expected
def test_create_function_def() -> None:
    """create_function_def forwards every field to the FunctionDef it builds."""
    expected_args = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg="x")],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    expected_body = [ast.Return(value=ast.Name(id="x", ctx=ast.Load()))]
    kwargs = {
        "name": "test_func",
        "args": expected_args,
        "body": expected_body,
        "decorator_list": [],
        "returns": None,
        "type_comment": None,
        "lineno": 1,
        "col_offset": 0,
        "end_lineno": 2,
        "end_col_offset": 0,
    }
    # Compare the whole version tuple rather than only the minor component.
    if sys.version_info >= (3, 12):
        kwargs["type_params"] = []
    func_def = ast_utils.create_function_def(**kwargs)
    assert isinstance(func_def, ast.FunctionDef)
    assert func_def.name == "test_func"
    assert func_def.args.posonlyargs == expected_args.posonlyargs
    assert func_def.args.args == expected_args.args
    assert func_def.args.kwonlyargs == expected_args.kwonlyargs
    assert func_def.args.kw_defaults == expected_args.kw_defaults
    assert func_def.args.defaults == expected_args.defaults
    assert func_def.body == expected_body
    assert func_def.decorator_list == []
    assert func_def.returns is None
    assert (
        func_def.lineno,
        func_def.col_offset,
        func_def.end_lineno,
        func_def.end_col_offset,
    ) == (1, 0, 2, 0)
    if sys.version_info >= (3, 12):
        assert func_def.type_params == []
================================================
FILE: src/latexify/codegen/__init__.py
================================================
"""Package latexify.codegen."""
from latexify.codegen import algorithmic_codegen, expression_codegen, function_codegen
# Public aliases re-exported from the codegen submodules, grouped by source module.
AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen
IPythonAlgorithmicCodegen = algorithmic_codegen.IPythonAlgorithmicCodegen
ExpressionCodegen = expression_codegen.ExpressionCodegen
FunctionCodegen = function_codegen.FunctionCodegen
================================================
FILE: src/latexify/codegen/algorithmic_codegen.py
================================================
"""Codegen for single algorithms."""
from __future__ import annotations
import ast
import contextlib
from collections.abc import Generator
from latexify import exceptions
from latexify.codegen import expression_codegen, identifier_converter
class AlgorithmicCodegen(ast.NodeVisitor):
    """Codegen for single algorithms.

    This codegen works for Module with single FunctionDef node to generate a single
    LaTeX expression of the given algorithm. The output uses the commands of the
    algorithmic environment (State, For, If, While, Function, ...).
    """

    _SPACES_PER_INDENT = 4

    _identifier_converter: identifier_converter.IdentifierConverter
    _indent_level: int

    def __init__(
        self,
        *,
        use_math_symbols: bool = False,
        use_set_symbols: bool = False,
        escape_underscores: bool = True,
    ) -> None:
        """Initializer.

        Args:
            use_math_symbols: Whether to convert identifiers with a math symbol surface
                (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
            use_set_symbols: Whether to use set symbols or not.
            escape_underscores: Whether to escape underscores in identifiers. Passed
                through to both the expression codegen and the identifier converter.
        """
        self._expression_codegen = expression_codegen.ExpressionCodegen(
            use_math_symbols=use_math_symbols,
            use_set_symbols=use_set_symbols,
            escape_underscores=escape_underscores,
        )
        # The function name is emitted outside of `$...$` in visit_FunctionDef;
        # presumably that is why use_mathrm is disabled for this converter.
        self._identifier_converter = identifier_converter.IdentifierConverter(
            use_math_symbols=use_math_symbols,
            use_mathrm=False,
            escape_underscores=escape_underscores,
        )
        self._indent_level = 0

    def generic_visit(self, node: ast.AST) -> str:
        raise exceptions.LatexifyNotSupportedError(
            f"Unsupported AST: {type(node).__name__}"
        )

    def visit_Assign(self, node: ast.Assign) -> str:
        """Visit an Assign node."""
        operands: list[str] = [
            self._expression_codegen.visit(target) for target in node.targets
        ]
        operands.append(self._expression_codegen.visit(node.value))
        operands_latex = r" \gets ".join(operands)
        return self._add_indent(rf"\State ${operands_latex}$")

    def visit_Expr(self, node: ast.Expr) -> str:
        """Visit an Expr node."""
        return self._add_indent(
            rf"\State ${self._expression_codegen.visit(node.value)}$"
        )

    def visit_For(self, node: ast.For) -> str:
        """Visit a For node."""
        if len(node.orelse) != 0:
            raise exceptions.LatexifyNotSupportedError(
                "For statement with the else clause is not supported"
            )

        target_latex = self._expression_codegen.visit(node.target)
        iter_latex = self._expression_codegen.visit(node.iter)
        with self._increment_level():
            body_latex = "\n".join(self.visit(stmt) for stmt in node.body)

        return (
            self._add_indent(f"\\For{{${target_latex} \\in {iter_latex}$}}\n")
            + f"{body_latex}\n"
            + self._add_indent("\\EndFor")
        )

    # TODO(ZibingZhang): support nested functions
    def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
        """Visit a FunctionDef node."""
        name_latex = self._identifier_converter.convert(node.name)[0]

        # Arguments (only positional-or-keyword ones are rendered).
        arg_strs = [
            self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
        ]

        latex = self._add_indent("\\begin{algorithmic}\n")
        with self._increment_level():
            latex += self._add_indent(
                f"\\Function{{{name_latex}}}{{${', '.join(arg_strs)}$}}\n"
            )

            with self._increment_level():
                # Body
                body_strs: list[str] = [self.visit(stmt) for stmt in node.body]
            body_latex = "\n".join(body_strs)

            latex += f"{body_latex}\n"
            latex += self._add_indent("\\EndFunction\n")
        return latex + self._add_indent(r"\end{algorithmic}")

    # TODO(ZibingZhang): support \ELSIF
    def visit_If(self, node: ast.If) -> str:
        """Visit an If node."""
        cond_latex = self._expression_codegen.visit(node.test)
        with self._increment_level():
            body_latex = "\n".join(self.visit(stmt) for stmt in node.body)

        latex = self._add_indent(f"\\If{{${cond_latex}$}}\n" + body_latex)

        if node.orelse:
            latex += "\n" + self._add_indent("\\Else\n")
            with self._increment_level():
                latex += "\n".join(self.visit(stmt) for stmt in node.orelse)

        return f"{latex}\n" + self._add_indent(r"\EndIf")

    def visit_Module(self, node: ast.Module) -> str:
        """Visit a Module node (only its first statement is rendered)."""
        return self.visit(node.body[0])

    def visit_Return(self, node: ast.Return) -> str:
        """Visit a Return node."""
        return (
            self._add_indent(
                rf"\State \Return ${self._expression_codegen.visit(node.value)}$"
            )
            if node.value is not None
            else self._add_indent(r"\State \Return")
        )

    def visit_While(self, node: ast.While) -> str:
        """Visit a While node."""
        if node.orelse:
            raise exceptions.LatexifyNotSupportedError(
                "While statement with the else clause is not supported"
            )

        cond_latex = self._expression_codegen.visit(node.test)
        with self._increment_level():
            body_latex = "\n".join(self.visit(stmt) for stmt in node.body)
        return (
            self._add_indent(f"\\While{{${cond_latex}$}}\n")
            + f"{body_latex}\n"
            + self._add_indent(r"\EndWhile")
        )

    def visit_Pass(self, node: ast.Pass) -> str:
        """Visit a Pass node."""
        return self._add_indent(r"\State $\mathbf{pass}$")

    def visit_Break(self, node: ast.Break) -> str:
        """Visit a Break node."""
        return self._add_indent(r"\State $\mathbf{break}$")

    def visit_Continue(self, node: ast.Continue) -> str:
        """Visit a Continue node."""
        return self._add_indent(r"\State $\mathbf{continue}$")

    @contextlib.contextmanager
    def _increment_level(self) -> Generator[None, None, None]:
        """Context manager controlling indent level.

        The level is restored even when the wrapped code raises (e.g. on an
        unsupported statement), so the instance is not left with a stale indent.
        """
        self._indent_level += 1
        try:
            yield
        finally:
            self._indent_level -= 1

    def _add_indent(self, line: str) -> str:
        """Adds an indent before the line.

        Args:
            line: The line to add an indent to.

        Returns:
            `line` prefixed by _SPACES_PER_INDENT spaces per indent level.
        """
        return self._indent_level * self._SPACES_PER_INDENT * " " + line
class IPythonAlgorithmicCodegen(ast.NodeVisitor):
    """Codegen for single algorithms targeting IPython.

    This codegen works for Module with single FunctionDef node to generate a single
    LaTeX expression of the given algorithm. Lines are joined with LaTeX line
    breaks and indented with hspace, inside a single-column array.
    """

    _EM_PER_INDENT = 1
    _LINE_BREAK = r" \\ "

    _identifier_converter: identifier_converter.IdentifierConverter
    _indent_level: int

    def __init__(
        self,
        *,
        use_math_symbols: bool = False,
        use_set_symbols: bool = False,
        escape_underscores: bool = True,
    ) -> None:
        """Initializer.

        Args:
            use_math_symbols: Whether to convert identifiers with a math symbol surface
                (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
            use_set_symbols: Whether to use set symbols or not.
            escape_underscores: Whether to escape underscores in identifiers. Passed
                through to both the expression codegen and the identifier converter.
        """
        self._expression_codegen = expression_codegen.ExpressionCodegen(
            use_math_symbols=use_math_symbols,
            use_set_symbols=use_set_symbols,
            escape_underscores=escape_underscores,
        )
        self._identifier_converter = identifier_converter.IdentifierConverter(
            use_math_symbols=use_math_symbols, escape_underscores=escape_underscores
        )
        self._indent_level = 0

    def generic_visit(self, node: ast.AST) -> str:
        raise exceptions.LatexifyNotSupportedError(
            f"Unsupported AST: {type(node).__name__}"
        )

    def visit_Assign(self, node: ast.Assign) -> str:
        """Visit an Assign node."""
        operands: list[str] = [
            self._expression_codegen.visit(target) for target in node.targets
        ]
        operands.append(self._expression_codegen.visit(node.value))
        operands_latex = r" \gets ".join(operands)
        return self._add_indent(operands_latex)

    def visit_Expr(self, node: ast.Expr) -> str:
        """Visit an Expr node."""
        return self._add_indent(self._expression_codegen.visit(node.value))

    def visit_For(self, node: ast.For) -> str:
        """Visit a For node."""
        if len(node.orelse) != 0:
            raise exceptions.LatexifyNotSupportedError(
                "For statement with the else clause is not supported"
            )

        target_latex = self._expression_codegen.visit(node.target)
        iter_latex = self._expression_codegen.visit(node.iter)
        with self._increment_level():
            body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)

        return (
            self._add_indent(r"\mathbf{for}")
            + rf" \ {target_latex} \in {iter_latex} \ \mathbf{{do}}{self._LINE_BREAK}"
            + f"{body_latex}{self._LINE_BREAK}"
            + self._add_indent(r"\mathbf{end \ for}")
        )

    # TODO(ZibingZhang): support nested functions
    def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
        """Visit a FunctionDef node."""
        name_latex = self._identifier_converter.convert(node.name)[0]

        # Arguments (only positional-or-keyword ones are rendered).
        args_latex = [
            self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
        ]
        # Body
        with self._increment_level():
            body_stmts_latex: list[str] = [self.visit(stmt) for stmt in node.body]
        body_latex = self._LINE_BREAK.join(body_stmts_latex)

        return (
            r"\begin{array}{l} "
            + self._add_indent(r"\mathbf{function}")
            + rf" \ {name_latex}({', '.join(args_latex)})"
            + f"{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}"
            + self._add_indent(r"\mathbf{end \ function}")
            + r" \end{array}"
        )

    # TODO(ZibingZhang): support \ELSIF
    def visit_If(self, node: ast.If) -> str:
        """Visit an If node."""
        cond_latex = self._expression_codegen.visit(node.test)
        with self._increment_level():
            body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)
        latex = self._add_indent(
            rf"\mathbf{{if}} \ {cond_latex}{self._LINE_BREAK}{body_latex}"
        )

        if node.orelse:
            latex += self._LINE_BREAK + self._add_indent(r"\mathbf{else} \\ ")
            with self._increment_level():
                latex += self._LINE_BREAK.join(self.visit(stmt) for stmt in node.orelse)

        return latex + self._LINE_BREAK + self._add_indent(r"\mathbf{end \ if}")

    def visit_Module(self, node: ast.Module) -> str:
        """Visit a Module node (only its first statement is rendered)."""
        return self.visit(node.body[0])

    def visit_Return(self, node: ast.Return) -> str:
        """Visit a Return node."""
        if node.value is None:
            return self._add_indent(r"\mathbf{return}")
        # Explicit branches instead of relying on `a + b if c else d` precedence.
        return self._add_indent(r"\mathbf{return} \ ") + self._expression_codegen.visit(
            node.value
        )

    def visit_While(self, node: ast.While) -> str:
        """Visit a While node."""
        if node.orelse:
            raise exceptions.LatexifyNotSupportedError(
                "While statement with the else clause is not supported"
            )

        cond_latex = self._expression_codegen.visit(node.test)
        with self._increment_level():
            body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body)
        return (
            self._add_indent(r"\mathbf{while} \ ")
            + f"{cond_latex}{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}"
            + self._add_indent(r"\mathbf{end \ while}")
        )

    def visit_Pass(self, node: ast.Pass) -> str:
        """Visit a Pass node."""
        return self._add_indent(r"\mathbf{pass}")

    def visit_Break(self, node: ast.Break) -> str:
        """Visit a Break node."""
        return self._add_indent(r"\mathbf{break}")

    def visit_Continue(self, node: ast.Continue) -> str:
        """Visit a Continue node."""
        return self._add_indent(r"\mathbf{continue}")

    @contextlib.contextmanager
    def _increment_level(self) -> Generator[None, None, None]:
        """Context manager controlling indent level.

        The level is restored even when the wrapped code raises (e.g. on an
        unsupported statement), so the instance is not left with a stale indent.
        """
        self._indent_level += 1
        try:
            yield
        finally:
            self._indent_level -= 1

    def _add_indent(self, line: str) -> str:
        """Adds an indent before the line.

        Args:
            line: The line to add an indent to.

        Returns:
            `line` unchanged at level 0; otherwise prefixed by an hspace of
            _EM_PER_INDENT em per indent level.
        """
        return (
            rf"\hspace{{{self._indent_level * self._EM_PER_INDENT}em}} {line}"
            if self._indent_level > 0
            else line
        )
================================================
FILE: src/latexify/codegen/algorithmic_codegen_test.py
================================================
"""Tests for latexify.codegen.algorithmic_codegen."""
from __future__ import annotations
import ast
import textwrap
import pytest
from latexify import exceptions
from latexify.codegen import algorithmic_codegen
def test_generic_visit() -> None:
    """Unknown node kinds are rejected with their class name in the message."""

    class UnknownNode(ast.AST):
        pass

    codegen = algorithmic_codegen.AlgorithmicCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match=r"^Unsupported AST: UnknownNode$",
    ):
        codegen.visit(UnknownNode())
@pytest.mark.parametrize(
    "code,latex",
    [
        ("x = 3", r"\State $x \gets 3$"),
        ("a = b = 0", r"\State $a \gets b \gets 0$"),
    ],
)
def test_visit_assign(code: str, latex: str) -> None:
    """Assignments, chained ones included, become a single State line."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.Assign)
    generated = algorithmic_codegen.AlgorithmicCodegen().visit(statement)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "for i in {1}: x = i",
            r"""
            \For{$i \in \mathopen{}\left\{ 1 \mathclose{}\right\}$}
                \State $x \gets i$
            \EndFor
            """,
        ),
    ],
)
def test_visit_for(code: str, latex: str) -> None:
    """A for loop becomes a For/EndFor pair with an indented body."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.For)
    generated = algorithmic_codegen.AlgorithmicCodegen().visit(statement)
    assert generated == textwrap.dedent(latex).strip()
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "def f(x): return x",
            r"""
            \begin{algorithmic}
                \Function{f}{$x$}
                    \State \Return $x$
                \EndFunction
            \end{algorithmic}
            """,
        ),
        (
            "def xyz(a, b, c): return 3",
            r"""
            \begin{algorithmic}
                \Function{xyz}{$a, b, c$}
                    \State \Return $3$
                \EndFunction
            \end{algorithmic}
            """,
        ),
    ],
)
def test_visit_functiondef(code: str, latex: str) -> None:
    """A function is wrapped in an algorithmic environment."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.FunctionDef)
    generated = algorithmic_codegen.AlgorithmicCodegen().visit(statement)
    assert generated == textwrap.dedent(latex).strip()
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "if x < y: return x",
            r"""
            \If{$x < y$}
                \State \Return $x$
            \EndIf
            """,
        ),
        (
            "if True: x\nelse: y",
            r"""
            \If{$\mathrm{True}$}
                \State $x$
            \Else
                \State $y$
            \EndIf
            """,
        ),
    ],
)
def test_visit_if(code: str, latex: str) -> None:
    """If statements, with and without else, map to If/Else/EndIf."""
    statement = ast.parse(code).body[0]
    assert isinstance(statement, ast.If)
    generated = algorithmic_codegen.AlgorithmicCodegen().visit(statement)
    assert generated == textwrap.dedent(latex).strip()
@pytest.mark.parametrize(
    "code,latex",
    [
        ("return x + y", r"\State \Return $x + y$"),
        ("return", r"\State \Return"),
    ],
)
def test_visit_return(code: str, latex: str) -> None:
    """Return statements render with or without a value."""
    statement = ast.parse(code).body[0]
    assert isinstance(statement, ast.Return)
    generated = algorithmic_codegen.AlgorithmicCodegen().visit(statement)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "while x < y: x = x + 1",
            r"""
            \While{$x < y$}
                \State $x \gets x + 1$
            \EndWhile
            """,
        )
    ],
)
def test_visit_while(code: str, latex: str) -> None:
    """A while loop becomes a While/EndWhile pair with an indented body."""
    statement = ast.parse(code).body[0]
    assert isinstance(statement, ast.While)
    generated = algorithmic_codegen.AlgorithmicCodegen().visit(statement)
    assert generated == textwrap.dedent(latex).strip()
def test_visit_while_with_else() -> None:
    """while/else is rejected as unsupported."""
    source = textwrap.dedent(
        """
        while True:
            x = x
        else:
            x = y
        """
    )
    statement = ast.parse(source).body[0]
    assert isinstance(statement, ast.While)
    codegen = algorithmic_codegen.AlgorithmicCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match="^While statement with the else clause is not supported$",
    ):
        codegen.visit(statement)
def test_visit_pass() -> None:
    """pass renders as a bold keyword State."""
    statement = ast.parse("pass").body[0]
    assert isinstance(statement, ast.Pass)
    expected = r"\State $\mathbf{pass}$"
    assert algorithmic_codegen.AlgorithmicCodegen().visit(statement) == expected
def test_visit_break() -> None:
    """break renders as a bold keyword State."""
    statement = ast.parse("break").body[0]
    assert isinstance(statement, ast.Break)
    expected = r"\State $\mathbf{break}$"
    assert algorithmic_codegen.AlgorithmicCodegen().visit(statement) == expected
def test_visit_continue() -> None:
    """continue renders as a bold keyword State."""
    statement = ast.parse("continue").body[0]
    assert isinstance(statement, ast.Continue)
    expected = r"\State $\mathbf{continue}$"
    assert algorithmic_codegen.AlgorithmicCodegen().visit(statement) == expected
@pytest.mark.parametrize(
    "code,latex",
    [
        ("x = 3", r"x \gets 3"),
        ("a = b = 0", r"a \gets b \gets 0"),
    ],
)
def test_visit_assign_ipython(code: str, latex: str) -> None:
    """IPython flavor: assignments render without the State wrapper."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.Assign)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "for i in {1}: x = i",
            (
                r"\mathbf{for} \ i \in \mathopen{}\left\{ 1 \mathclose{}\right\}"
                r" \ \mathbf{do} \\"
                r" \hspace{1em} x \gets i \\"
                r" \mathbf{end \ for}"
            ),
        ),
    ],
)
def test_visit_for_ipython(code: str, latex: str) -> None:
    """IPython flavor: for loops use hspace indentation and line breaks."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.For)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == textwrap.dedent(latex).strip()
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "def f(x): return x",
            (
                r"\begin{array}{l}"
                r" \mathbf{function}"
                r" \ f(x) \\"
                r" \hspace{1em} \mathbf{return} \ x \\"
                r" \mathbf{end \ function}"
                r" \end{array}"
            ),
        ),
        (
            "def f(a, b, c): return 3",
            (
                r"\begin{array}{l}"
                r" \mathbf{function}"
                r" \ f(a, b, c) \\"
                r" \hspace{1em} \mathbf{return} \ 3 \\"
                r" \mathbf{end \ function}"
                r" \end{array}"
            ),
        ),
    ],
)
def test_visit_functiondef_ipython(code: str, latex: str) -> None:
    """IPython flavor: functions are wrapped in a single-column array."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.FunctionDef)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "if x < y: return x",
            (
                r"\mathbf{if} \ x < y \\"
                r" \hspace{1em} \mathbf{return} \ x \\"
                r" \mathbf{end \ if}"
            ),
        ),
        (
            "if True: x\nelse: y",
            (
                r"\mathbf{if} \ \mathrm{True} \\"
                r" \hspace{1em} x \\"
                r" \mathbf{else} \\"
                r" \hspace{1em} y \\"
                r" \mathbf{end \ if}"
            ),
        ),
    ],
)
def test_visit_if_ipython(code: str, latex: str) -> None:
    """IPython flavor: if/else renders with bold keywords."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.If)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("return x + y", r"\mathbf{return} \ x + y"),
        ("return", r"\mathbf{return}"),
    ],
)
def test_visit_return_ipython(code: str, latex: str) -> None:
    """IPython flavor: return renders with or without a value."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.Return)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "while x < y: x = x + 1",
            (
                r"\mathbf{while} \ x < y \\"
                r" \hspace{1em} x \gets x + 1 \\"
                r" \mathbf{end \ while}"
            ),
        )
    ],
)
def test_visit_while_ipython(code: str, latex: str) -> None:
    """IPython flavor: while loops use hspace indentation and line breaks."""
    statement = ast.parse(textwrap.dedent(code)).body[0]
    assert isinstance(statement, ast.While)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == latex
def test_visit_while_with_else_ipython() -> None:
    """IPython flavor: while/else is rejected as unsupported."""
    source = textwrap.dedent(
        """
        while True:
            x = x
        else:
            x = y
        """
    )
    statement = ast.parse(source).body[0]
    assert isinstance(statement, ast.While)
    codegen = algorithmic_codegen.IPythonAlgorithmicCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match="^While statement with the else clause is not supported$",
    ):
        codegen.visit(statement)
def test_visit_pass_ipython() -> None:
    """IPython flavor: pass renders as a bold keyword."""
    statement = ast.parse("pass").body[0]
    assert isinstance(statement, ast.Pass)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == r"\mathbf{pass}"
def test_visit_break_ipython() -> None:
    """IPython flavor: break renders as a bold keyword."""
    statement = ast.parse("break").body[0]
    assert isinstance(statement, ast.Break)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == r"\mathbf{break}"
def test_visit_continue_ipython() -> None:
    """IPython flavor: continue renders as a bold keyword."""
    statement = ast.parse("continue").body[0]
    assert isinstance(statement, ast.Continue)
    generated = algorithmic_codegen.IPythonAlgorithmicCodegen().visit(statement)
    assert generated == r"\mathbf{continue}"
================================================
FILE: src/latexify/codegen/codegen_utils.py
================================================
from typing import Any
from latexify import exceptions
def convert_constant(value: Any) -> str:
    """Renders a Python constant value as a LaTeX fragment.

    Args:
        value: A constant value.

    Returns:
        The LaTeX representation of `value`.

    Raises:
        LatexifyNotSupportedError: `value` has a type with no LaTeX rendering.
    """
    if value is ...:
        return r"\cdots"
    # bool has to be handled before the numeric branch: it is a subclass of int.
    if value is None or isinstance(value, bool):
        return "".join((r"\mathrm{", str(value), "}"))
    if isinstance(value, (int, float, complex)):
        # TODO(odashi): Support other symbols for the imaginary unit than j.
        return str(value)
    if isinstance(value, str):
        return "".join((r'\textrm{"', value, '"}'))
    if isinstance(value, bytes):
        return "".join((r"\textrm{", str(value), "}"))
    raise exceptions.LatexifyNotSupportedError(
        f"Unrecognized constant: {type(value).__name__}"
    )
================================================
FILE: src/latexify/codegen/codegen_utils_test.py
================================================
"""Tests for latexify.codegen.codegen_utils."""
from __future__ import annotations
from typing import Any
import pytest
from latexify import exceptions
from latexify.codegen.codegen_utils import convert_constant
@pytest.mark.parametrize(
    "constant,latex",
    [
        # Singletons are typeset upright.
        (None, r"\mathrm{None}"),
        (True, r"\mathrm{True}"),
        (False, r"\mathrm{False}"),
        # Numbers use their Python repr.
        (123, "123"),
        (456.789, "456.789"),
        (-3 + 4j, "(-3+4j)"),
        # Strings are quoted text; Ellipsis becomes cdots.
        ("string", r'\textrm{"string"}'),
        (..., r"\cdots"),
    ],
)
def test_convert_constant(constant: Any, latex: str) -> None:
    """Each supported constant maps to its LaTeX rendering."""
    observed = convert_constant(constant)
    assert observed == latex
def test_convert_constant_unsupported_constant() -> None:
    """Values of unsupported types raise LatexifyNotSupportedError."""
    unsupported: dict = {}
    with pytest.raises(
        exceptions.LatexifyNotSupportedError, match="^Unrecognized constant: "
    ):
        convert_constant(unsupported)
================================================
FILE: src/latexify/codegen/expression_codegen.py
================================================
"""Codegen for single expressions."""
from __future__ import annotations
import ast
import re
from latexify import analyzers, ast_utils, exceptions
from latexify.codegen import codegen_utils, expression_rules, identifier_converter
class ExpressionCodegen(ast.NodeVisitor):
"""Codegen for single expressions."""
_identifier_converter: identifier_converter.IdentifierConverter
_bin_op_rules: dict[type[ast.operator], expression_rules.BinOpRule]
_compare_ops: dict[type[ast.cmpop], str]
def __init__(
    self,
    *,
    use_math_symbols: bool = False,
    use_set_symbols: bool = False,
    escape_underscores: bool = True,
) -> None:
    """Initializer.

    Args:
        use_math_symbols: Whether to convert identifiers with a math symbol
            surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
        use_set_symbols: Whether to use set symbols or not. Selects the SET_*
            binary-operator and comparison tables from expression_rules.
        escape_underscores: Whether to escape underscores in identifiers; passed
            through to the identifier converter.
    """
    self._identifier_converter = identifier_converter.IdentifierConverter(
        use_math_symbols=use_math_symbols, escape_underscores=escape_underscores
    )
    if use_set_symbols:
        self._bin_op_rules = expression_rules.SET_BIN_OP_RULES
        self._compare_ops = expression_rules.SET_COMPARE_OPS
    else:
        self._bin_op_rules = expression_rules.BIN_OP_RULES
        self._compare_ops = expression_rules.COMPARE_OPS
def generic_visit(self, node: ast.AST) -> str:
    """Rejects every node kind that has no dedicated visit_* method."""
    node_kind = type(node).__name__
    raise exceptions.LatexifyNotSupportedError(f"Unsupported AST: {node_kind}")
def visit_Tuple(self, node: ast.Tuple) -> str:
"""Visit a Tuple node."""
elts = [self.visit(elt) for elt in node.elts]
return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)"
def visit_List(self, node: ast.List) -> str:
"""Visit a List node."""
elts = [self.visit(elt) for elt in node.elts]
return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]"
def visit_Set(self, node: ast.Set) -> str:
"""Visit a Set node."""
elts = [self.visit(elt) for elt in node.elts]
return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}"
def visit_ListComp(self, node: ast.ListComp) -> str:
"""Visit a ListComp node."""
generators = [self.visit(comp) for comp in node.generators]
return (
r"\mathopen{}\left[ "
+ self.visit(node.elt)
+ r" \mid "
+ ", ".join(generators)
+ r" \mathclose{}\right]"
)
def visit_SetComp(self, node: ast.SetComp) -> str:
"""Visit a SetComp node."""
generators = [self.visit(comp) for comp in node.generators]
return (
r"\mathopen{}\left\{ "
+ self.visit(node.elt)
+ r" \mid "
+ ", ".join(generators)
+ r" \mathclose{}\right\}"
)
def visit_comprehension(self, node: ast.comprehension) -> str:
"""Visit a comprehension node."""
target = rf"{self.visit(node.target)} \in {self.visit(node.iter)}"
if not node.ifs:
# Returns the source without parenthesis.
return target
conds = [target] + [self.visit(cond) for cond in node.ifs]
wrapped = [r"\mathopen{}\left( " + s + r" \mathclose{}\right)" for s in conds]
return r" \land ".join(wrapped)
def _generate_sum_prod(self, node: ast.Call) -> str | None:
"""Generates sum/prod expression.
Args:
node: ast.Call node containing the sum/prod invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
"""
if not node.args or not isinstance(node.args[0], ast.GeneratorExp):
return None
name = ast_utils.extract_function_name_or_none(node)
assert name in ("fsum", "sum", "prod")
command = {
"fsum": r"\sum",
"sum": r"\sum",
"prod": r"\prod",
}[name]
elt, scripts = self._get_sum_prod_info(node.args[0])
scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts]
return (
" ".join(scripts_str)
+ rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)"
)
def _generate_matrix(self, node: ast.Call) -> str | None:
"""Generates matrix expression.
Args:
node: ast.Call node containing the ndarray invocation.
Returns:
Generated LaTeX, or None if the node has unsupported syntax.
"""
def generate_matrix_from_array(data: list[list[str]]) -> str:
"""Helper to generate a bmatrix environment."""
contents = r" \\ ".join(" & ".join(row) for row in data)
return r"\begin{bmatrix} " + contents + r" \end{bmatrix}"
arg = node.args[0]
if not isinstance(arg, ast.List) or not arg.elts:
# Not an array or no rows
return None
row0 = arg.elts[0]
if not isinstance(row0, ast.List):
# Maybe 1 x N array
return generate_matrix_from_array([[self.visit(x) for x in arg.elts]])
if not row0.elts:
# No columns
return None
ncols = len(row0.elts)
rows: list[list[str]] = []
for row in arg.elts:
if not isinstance(row, ast.List) or len(row.elts) != ncols:
# Length mismatch
return None
rows.append([self.visit(x) for x in row.elts])
return generate_matrix_from_array(rows)
def _generate_zeros(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.zeros.

    A scalar or 1-tuple shape N is rendered as a 1 x N zero matrix, a longer
    tuple as the product of its dimensions, and the empty tuple as plain "0".

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax.
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "zeros"
    if len(node.args) != 1:
        return None
    shape = node.args[0]

    # All args to np.zeros should be numeric.
    if not isinstance(shape, ast.Tuple):
        dim = ast_utils.extract_int_or_none(shape)
        if not isinstance(dim, int):
            return None
        # 1 x N array of zeros
        return rf"\mathbf{{0}}^{{1 \times {dim}}}"

    dims = [ast_utils.extract_int_or_none(elt) for elt in shape.elts]
    if None in dims:
        return None
    if not dims:
        return "0"
    if len(dims) == 1:
        dims.insert(0, 1)
    return r"\mathbf{0}^{" + r" \times ".join(map(str, dims)) + "}"
def _generate_identity(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.identity.

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        `\\mathbf{I}_{n}` for a single integer-literal argument n, or None if the
        node has unsupported syntax.
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "identity"
    if len(node.args) == 1:
        size = ast_utils.extract_int_or_none(node.args[0])
        if size is not None:
            return rf"\mathbf{{I}}_{{{size}}}"
    return None
def _generate_transpose(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.transpose.

    Only a single plain variable argument is supported; it is rendered in bold
    with a transpose superscript.

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax.
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "transpose"
    if len(node.args) != 1:
        return None
    (operand,) = node.args
    if not isinstance(operand, ast.Name):
        return None
    return rf"\mathbf{{{operand.id}}}^\intercal"
def _generate_determinant(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.linalg.det.

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax. This
        includes a list literal that _generate_matrix cannot turn into a
        matrix (e.g. no rows/columns or rows of different lengths).
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "det"
    if len(node.args) != 1:
        return None

    func_arg = node.args[0]
    if isinstance(func_arg, ast.Name):
        arg_latex = rf"\mathbf{{{func_arg.id}}}"
    elif isinstance(func_arg, ast.List):
        arg_latex = self._generate_matrix(node)
        if arg_latex is None:
            # Do not embed the string "None"; let the caller fall back.
            return None
    else:
        return None

    return rf"\det \mathopen{{}}\left( {arg_latex} \mathclose{{}}\right)"
def _generate_matrix_rank(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.linalg.matrix_rank.

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax. This
        includes a list literal that _generate_matrix cannot turn into a
        matrix (e.g. no rows/columns or rows of different lengths).
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "matrix_rank"
    if len(node.args) != 1:
        return None

    func_arg = node.args[0]
    if isinstance(func_arg, ast.Name):
        arg_latex = rf"\mathbf{{{func_arg.id}}}"
    elif isinstance(func_arg, ast.List):
        arg_latex = self._generate_matrix(node)
        if arg_latex is None:
            # Do not embed the string "None"; let the caller fall back.
            return None
    else:
        return None

    return rf"\mathrm{{rank}} \mathopen{{}}\left( {arg_latex} \mathclose{{}}\right)"
def _generate_matrix_power(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.linalg.matrix_power.

    The exponent must be a numeric literal; the base may be a plain variable
    (rendered in bold) or a list literal (rendered as a bmatrix).

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax.
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "matrix_power"
    if len(node.args) != 2:
        return None
    base, exponent = node.args
    if not isinstance(exponent, ast.Num):
        return None

    if isinstance(base, ast.Name):
        return rf"\mathbf{{{base.id}}}^{{{exponent.n}}}"
    if not isinstance(base, ast.List):
        return None

    matrix = self._generate_matrix(node)
    if matrix is None:
        return None
    return rf"{matrix}^{{{exponent.n}}}"
def _generate_inv(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.linalg.inv.

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax. This
        includes a list literal that _generate_matrix cannot turn into a
        matrix (e.g. no rows/columns or rows of different lengths).
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "inv"
    if len(node.args) != 1:
        return None

    func_arg = node.args[0]
    if isinstance(func_arg, ast.Name):
        return rf"\mathbf{{{func_arg.id}}}^{{-1}}"
    if isinstance(func_arg, ast.List):
        matrix = self._generate_matrix(node)
        if matrix is None:
            # Do not embed the string "None"; let the caller fall back.
            return None
        return rf"{matrix}^{{-1}}"
    return None
def _generate_pinv(self, node: ast.Call) -> str | None:
    """Generates LaTeX for numpy.linalg.pinv.

    Args:
        node: ast.Call node containing the appropriate method invocation.

    Returns:
        Generated LaTeX, or None if the node has unsupported syntax. This
        includes a list literal that _generate_matrix cannot turn into a
        matrix (e.g. no rows/columns or rows of different lengths).
    """
    name = ast_utils.extract_function_name_or_none(node)
    assert name == "pinv"
    if len(node.args) != 1:
        return None

    func_arg = node.args[0]
    if isinstance(func_arg, ast.Name):
        return rf"\mathbf{{{func_arg.id}}}^{{+}}"
    if isinstance(func_arg, ast.List):
        matrix = self._generate_matrix(node)
        if matrix is None:
            # Do not embed the string "None"; let the caller fall back.
            return None
        return rf"{matrix}^{{+}}"
    return None
def visit_Call(self, node: ast.Call) -> str:
    """Visit a Call node.

    A few well-known callees (sum/prod and several numpy helpers) are first
    offered to a dedicated generator. If there is none for the callee, or it
    returns None, the call is rendered with the builtin function rule for its
    name, or with a generic FunctionRule built from the callee's LaTeX.
    """
    func_name = ast_utils.extract_function_name_or_none(node)

    # Special treatments for some functions.
    # TODO(odashi): Move these functions to some separate utility.
    special_generators = {
        "fsum": self._generate_sum_prod,
        "sum": self._generate_sum_prod,
        "prod": self._generate_sum_prod,
        "array": self._generate_matrix,
        "ndarray": self._generate_matrix,
        "zeros": self._generate_zeros,
        "identity": self._generate_identity,
        "transpose": self._generate_transpose,
        "det": self._generate_determinant,
        "matrix_rank": self._generate_matrix_rank,
        "matrix_power": self._generate_matrix_power,
        "inv": self._generate_inv,
        "pinv": self._generate_pinv,
    }
    generator = special_generators.get(func_name)
    special_latex = generator(node) if generator is not None else None
    if special_latex is not None:
        return special_latex

    # Obtains the codegen rule.
    rule = None
    if func_name is not None:
        rule = expression_rules.BUILTIN_FUNCS.get(func_name)
    if rule is None:
        rule = expression_rules.FunctionRule(self.visit(node.func))

    if not (rule.is_unary and len(node.args) == 1):
        # General form: comma-separated arguments, parenthesized unless the
        # rule already provides its own delimiters.
        args_latex = ", ".join(self.visit(a) for a in node.args)
        if rule.is_wrapped:
            parts = [rule.left, args_latex, rule.right]
        else:
            parts = [
                rule.left,
                r"\mathopen{}\left(",
                args_latex,
                r"\mathclose{}\right)",
                rule.right,
            ]
        return " ".join(filter(None, parts))

    # Unary function. Applies the same wrapping policy with the unary operators.
    precedence = expression_rules.get_precedence(node)
    operand = node.args[0]
    # NOTE(odashi):
    # Factorial "x!" is treated as a special case: it requires both inner/outer
    # parentheses for correct interpretation.
    force_wrap_factorial = isinstance(operand, ast.Call) and (
        func_name == "factorial"
        or ast_utils.extract_function_name_or_none(operand) == "factorial"
    )
    # Note(odashi):
    # Wrapping is also required if the argument is pow.
    # https://github.com/google/latexify_py/issues/189
    force_wrap_pow = isinstance(operand, ast.BinOp) and isinstance(
        operand.op, ast.Pow
    )
    operand_latex = self._wrap_operand(
        operand, precedence, force_wrap_factorial or force_wrap_pow
    )
    return " ".join(filter(None, [rule.left, operand_latex, rule.right]))
def visit_Attribute(self, node: ast.Attribute) -> str:
    """Visit an Attribute node, rendering it as `<value>.<converted attr>`."""
    value_latex = self.visit(node.value)
    attr_latex = self._identifier_converter.convert(node.attr)[0]
    return f"{value_latex}.{attr_latex}"
def visit_Name(self, node: ast.Name) -> str:
    """Visit a Name node, returning the LaTeX from the identifier converter."""
    converted = self._identifier_converter.convert(node.id)
    return converted[0]
# From Python 3.8
def visit_Constant(self, node: ast.Constant) -> str:
    """Visit a Constant node by converting its literal value to LaTeX."""
    value = node.value
    return codegen_utils.convert_constant(value)
# Until Python 3.7
def visit_Num(self, node: ast.Num) -> str:
    """Visit a Num node (numeric literal node used up to Python 3.7)."""
    number = node.n
    return codegen_utils.convert_constant(number)
# Until Python 3.7
def visit_Str(self, node: ast.Str) -> str:
    """Visit a Str node (string literal node used up to Python 3.7)."""
    text = node.s
    return codegen_utils.convert_constant(text)
# Until Python 3.7
def visit_Bytes(self, node: ast.Bytes) -> str:
    """Visit a Bytes node (bytes literal node used up to Python 3.7)."""
    payload = node.s
    return codegen_utils.convert_constant(payload)
# Until Python 3.7
def visit_NameConstant(self, node: ast.NameConstant) -> str:
    """Visit a NameConstant node (named-constant node used up to Python 3.7)."""
    constant = node.value
    return codegen_utils.convert_constant(constant)
# Until Python 3.7
def visit_Ellipsis(self, node: ast.Ellipsis) -> str:
    """Visit an Ellipsis node; the rendered value is always the `...` literal."""
    return codegen_utils.convert_constant(Ellipsis)
def _wrap_operand(
    self, child: ast.expr, parent_prec: int, force_wrap: bool = False
) -> str:
    """Renders an operand, parenthesizing it when it binds looser than its parent.

    Args:
        child: Operand subtree.
        parent_prec: Precedence of the parent operator.
        force_wrap: If True, parentheses are added regardless of precedence.

    Returns:
        LaTeX form of `child`, with or without surrounding parentheses.
    """
    latex = self.visit(child)
    child_prec = expression_rules.get_precedence(child)
    if not force_wrap and child_prec >= parent_prec:
        return latex
    return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
def _wrap_binop_operand(
    self,
    child: ast.expr,
    parent_prec: int,
    operand_rule: expression_rules.BinOperandRule,
) -> str:
    """Wraps the operand subtree of BinOp with parentheses.

    Args:
        child: Operand subtree.
        parent_prec: Precedence of the parent operator.
        operand_rule: Syntax rule of this operand.

    Returns:
        LaTeX form of the `child`, with or without surrounding parentheses.
    """
    if not operand_rule.wrap:
        # This operand position never takes parentheses.
        return self.visit(child)

    if isinstance(child, ast.Call):
        # Builtin functions whose LaTeX already carries its own delimiters
        # never need extra parentheses.
        child_fn_name = ast_utils.extract_function_name_or_none(child)
        rule = (
            expression_rules.BUILTIN_FUNCS.get(child_fn_name)
            if child_fn_name is not None
            else None
        )
        if rule is not None and rule.is_wrapped:
            return self.visit(child)

    if not isinstance(child, ast.BinOp):
        return self._wrap_operand(child, parent_prec)

    latex = self.visit(child)

    # Consult the same rule table visit_BinOp used to render `child`, so the
    # wrapping decision agrees with the LaTeX that was actually produced.
    if self._bin_op_rules[type(child.op)].is_wrapped:
        return latex

    child_prec = expression_rules.get_precedence(child)
    if child_prec > parent_prec or (
        child_prec == parent_prec and not operand_rule.force
    ):
        return latex
    return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)"
# Patterns used by _should_remove_multiply_op to classify operand LaTeX.
# They are applied there with Pattern.match(), i.e. anchored at the start.
# Starts with \mathopen, e.g. the "\mathopen{}\left(" prefix of wrapped operands.
_l_bracket_pattern = re.compile(r"^\\mathopen.*")
# Ends with \mathclose followed by a space-free tail, e.g. "\mathclose{}\right)".
_r_bracket_pattern = re.compile(r".*\\mathclose[^ ]+$")
# NOTE(review): because .match() anchors at the start, this only matches LaTeX
# that starts with "\mathrm{" and ends with "}" with no spaces in between, not
# every LaTeX that merely ends with a \mathrm{...} word.
_r_word_pattern = re.compile(r"\\mathrm\{[^ ]+\}$")
def _should_remove_multiply_op(
    self, l_latex: str, r_latex: str, l_expr: ast.expr, r_expr: ast.expr
) -> bool:
    r"""Determines whether the multiply operator should be removed or not.

    See also:
    https://github.com/google/latexify_py/issues/89#issuecomment-1344967636

    This is an ad-hoc implementation.
    This function doesn't fully implement the above requirements, but only
    the essential ones necessary to release v0.3.

    Each operand is classified into one of the following types:
        "f": function call, "b": bracketed, "w": \mathrm{...} word,
        "n": numeric character next to the operator,
        "a": single-character name, "m": anything else.

    Args:
        l_latex: Generated LaTeX of the left operand.
        r_latex: Generated LaTeX of the right operand.
        l_expr: AST of the left operand.
        r_expr: AST of the right operand.

    Returns:
        True if the operands can be juxtaposed without an operator, False
        otherwise.
    """
    # NOTE(odashi): For compatibility with Python 3.7, we compare the generated
    # character type directly to determine the "numeric" type.
    # The edge characters are taken by slicing so that an empty string is
    # classified instead of raising IndexError.
    if isinstance(l_expr, ast.Call):
        l_type = "f"
    elif self._r_bracket_pattern.match(l_latex):
        l_type = "b"
    elif self._r_word_pattern.match(l_latex):
        l_type = "w"
    elif l_latex[-1:].isnumeric():
        l_type = "n"
    else:
        # Descend to the right-most leaf, i.e. the part adjacent to the operator.
        l_node = l_expr
        while True:
            if isinstance(l_node, ast.UnaryOp):
                l_node = l_node.operand
            elif isinstance(l_node, ast.BinOp):
                l_node = l_node.right
            elif isinstance(l_node, ast.Compare):
                l_node = l_node.comparators[-1]
            elif isinstance(l_node, ast.BoolOp):
                l_node = l_node.values[-1]
            else:
                break
        l_type = "a" if isinstance(l_node, ast.Name) and len(l_node.id) == 1 else "m"

    if isinstance(r_expr, ast.Call):
        r_type = "f"
    elif self._l_bracket_pattern.match(r_latex):
        r_type = "b"
    elif r_latex.startswith("\\mathrm"):
        r_type = "w"
    elif r_latex[:1].isnumeric():
        r_type = "n"
    else:
        # Descend to the left-most leaf. Named `r_node` (not `re`) so the
        # module-level `re` import is not shadowed inside this method.
        r_node = r_expr
        while True:
            if isinstance(r_node, ast.UnaryOp):
                if isinstance(r_node.op, ast.USub):
                    # NOTE(odashi): Unary "-" always requires \cdot.
                    return False
                r_node = r_node.operand
            elif isinstance(r_node, ast.BinOp):
                r_node = r_node.left
            elif isinstance(r_node, ast.Compare):
                r_node = r_node.left
            elif isinstance(r_node, ast.BoolOp):
                r_node = r_node.values[0]
            else:
                break
        r_type = "a" if isinstance(r_node, ast.Name) and len(r_node.id) == 1 else "m"

    if r_type == "n":
        return False
    if l_type in "bn":
        return True
    if l_type in "am" and r_type in "am":
        return True
    return False
def visit_BinOp(self, node: ast.BinOp) -> str:
    """Visit a BinOp node.

    Both operands are wrapped according to the operator's rule. For `*` and
    `@`, the operator symbol is replaced by a single space when
    _should_remove_multiply_op allows juxtaposition.
    """
    prec = expression_rules.get_precedence(node)
    rule = self._bin_op_rules[type(node.op)]
    lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left)
    rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right)

    middle = rule.latex_middle
    is_product = type(node.op) in (ast.Mult, ast.MatMult)
    if is_product and self._should_remove_multiply_op(
        lhs, rhs, node.left, node.right
    ):
        middle = " "
    return f"{rule.latex_left}{lhs}{middle}{rhs}{rule.latex_right}"
def visit_UnaryOp(self, node: ast.UnaryOp) -> str:
    """Visit a UnaryOp node: operator LaTeX followed by the wrapped operand."""
    prec = expression_rules.get_precedence(node)
    operand_latex = self._wrap_operand(node.operand, prec)
    op_latex = expression_rules.UNARY_OPS[type(node.op)]
    return op_latex + operand_latex
def visit_Compare(self, node: ast.Compare) -> str:
    """Visit a Compare node, rendering a (possibly chained) comparison."""
    prec = expression_rules.get_precedence(node)
    head = self._wrap_operand(node.left, prec)
    op_strs = [self._compare_ops[type(op)] for op in node.ops]
    tails = [self._wrap_operand(c, prec) for c in node.comparators]
    return head + "".join(f" {op} {tail}" for op, tail in zip(op_strs, tails))
def visit_BoolOp(self, node: ast.BoolOp) -> str:
    """Visit a BoolOp node, joining the wrapped operands with the operator."""
    prec = expression_rules.get_precedence(node)
    operands = [self._wrap_operand(value, prec) for value in node.values]
    separator = f" {expression_rules.BOOL_OPS[type(node.op)]} "
    return separator.join(operands)
def visit_IfExp(self, node: ast.IfExp) -> str:
    """Visit an IfExp node.

    A chain of conditional expressions nested in the `else` branches is
    flattened into a single cases-like array, one row per condition, followed
    by an "otherwise" row.
    """
    rows = []
    expr: ast.expr = node
    while isinstance(expr, ast.IfExp):
        cond = self.visit(expr.test)
        value = self.visit(expr.body)
        rows.append(value + r", & \mathrm{if} \ " + cond + r" \\ ")
        expr = expr.orelse
    fallback = self.visit(expr)
    return (
        r"\left\{ \begin{array}{ll} "
        + "".join(rows)
        + fallback
        + r", & \mathrm{otherwise} \end{array} \right."
    )
def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None:
    """Extracts big-operator bounds from a `for ... in range(...)` clause.

    Args:
        node: comprehension node to be analyzed.

    Returns:
        A pair (lower_rhs, upper) of LaTeX strings used by _get_sum_prod_info,
        where `upper` is the last value the range produces. Returns None when
        the clause does not iterate over an ascending, step-1 range(...) call,
        or when the range cannot be analyzed.
    """
    iterable = node.iter
    is_range_call = (
        isinstance(iterable, ast.Call)
        and isinstance(iterable.func, ast.Name)
        and iterable.func.id == "range"
    )
    if not is_range_call:
        return None

    try:
        info = analyzers.analyze_range(iterable)
    except exceptions.LatexifyError:
        return None

    # Only accepts ascending order with step size 1.
    if info.step_int != 1:
        return None
    bounds_known = info.start_int is not None and info.stop_int is not None
    if bounds_known and info.start_int >= info.stop_int:
        return None

    if info.start_int is not None:
        lower_rhs = str(info.start_int)
    else:
        lower_rhs = self.visit(info.start)

    if info.stop_int is not None:
        # range() excludes its stop value, so the last value is stop - 1.
        upper = str(info.stop_int - 1)
    else:
        upper = self.visit(analyzers.reduce_stop_parameter(info.stop))

    return lower_rhs, upper
def _get_sum_prod_info(
    self, node: ast.GeneratorExp
) -> tuple[str, list[tuple[str, str]]]:
    r"""Process GeneratorExp for sum and prod functions.

    Args:
        node: GeneratorExp node to be analyzed.

    Returns:
        Tuple of following values:
            - elt: LaTeX of the generated element.
            - scripts: one (subscript, superscript) pair per comprehension.
        They are used to represent sum/prod operators as follows:
            \sum_{scripts[0][0]}^{scripts[0][1]}
            \sum_{scripts[1][0]}^{scripts[1][1]}
            ...
            {elt}

    Raises:
        LatexifyError: Unsupported AST is given.
    """
    elt = self.visit(node.elt)
    scripts: list[tuple[str, str]] = []
    for comp in node.generators:
        range_args = self._get_sum_prod_range(comp)
        if range_args is None or comp.ifs:
            # Use a usual comprehension form, with an empty superscript.
            scripts.append((self.visit(comp), ""))
            continue
        target_latex = self.visit(comp.target)
        lower_rhs, upper = range_args
        scripts.append((f"{target_latex} = {lower_rhs}", upper))
    return elt, scripts
# Until 3.8
def visit_Index(self, node: ast.Index) -> str:
    """Visit an Index node by rendering the wrapped value."""
    inner = node.value  # type: ignore[attr-defined]
    return self.visit(inner)
def _convert_nested_subscripts(self, node: ast.Subscript) -> tuple[str, list[str]]:
    """Helper function to convert nested subscription.

    This function converts x[i][j][...] to "x" and ["i", "j", ...]. Nodes are
    visited root first, then indices from the innermost to the outermost
    subscript.

    Args:
        node: ast.Subscript node to be converted.

    Returns:
        Tuple of following values:
            - The root value of the subscription.
            - Sequence of indices.
    """
    # Walk from the outermost subscript down to the root, collecting slices.
    slices = [node.slice]
    current = node.value
    while isinstance(current, ast.Subscript):
        slices.append(current.slice)
        current = current.value

    root = self.visit(current)
    indices = [self.visit(s) for s in reversed(slices)]
    return root, indices
def visit_Subscript(self, node: ast.Subscript) -> str:
    """Visit a Subscript node, rendering x[i][j] as `x_{i, j}`."""
    root, indices = self._convert_nested_subscripts(node)
    # TODO(odashi):
    # "[i][j][...]" may be a possible representation as well as "i, j. ..."
    return root + "_{" + ", ".join(indices) + "}"
================================================
FILE: src/latexify/codegen/expression_codegen_test.py
================================================
"""Tests for latexify.codegen.expression_codegen."""
from __future__ import annotations
import ast
import pytest
from latexify import ast_utils, exceptions
from latexify.codegen import expression_codegen
def test_generic_visit() -> None:
    # The class name appears in the expected error message; do not rename it.
    class UnknownNode(ast.AST):
        pass

    codegen = expression_codegen.ExpressionCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match=r"^Unsupported AST: UnknownNode$",
    ):
        codegen.visit(UnknownNode())
@pytest.mark.parametrize(
    "code,latex",
    [
        ("()", r"\mathopen{}\left( \mathclose{}\right)"),
        ("(x,)", r"\mathopen{}\left( x \mathclose{}\right)"),
        ("(x, y)", r"\mathopen{}\left( x, y \mathclose{}\right)"),
        ("(x, y, z)", r"\mathopen{}\left( x, y, z \mathclose{}\right)"),
    ],
)
def test_visit_tuple(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Tuple)
    generated = expression_codegen.ExpressionCodegen().visit(tree)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("[]", r"\mathopen{}\left[ \mathclose{}\right]"),
        ("[x]", r"\mathopen{}\left[ x \mathclose{}\right]"),
        ("[x, y]", r"\mathopen{}\left[ x, y \mathclose{}\right]"),
        ("[x, y, z]", r"\mathopen{}\left[ x, y, z \mathclose{}\right]"),
    ],
)
def test_visit_list(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.List)
    generated = expression_codegen.ExpressionCodegen().visit(tree)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        # TODO(odashi): Support set().
        # ("set()", r"\mathopen{}\left\{ \mathclose{}\right\}"),
        ("{x}", r"\mathopen{}\left\{ x \mathclose{}\right\}"),
        ("{x, y}", r"\mathopen{}\left\{ x, y \mathclose{}\right\}"),
        ("{x, y, z}", r"\mathopen{}\left\{ x, y, z \mathclose{}\right\}"),
    ],
)
def test_visit_set(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Set)
    generated = expression_codegen.ExpressionCodegen().visit(tree)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("[i for i in n]", r"\mathopen{}\left[ i \mid i \in n \mathclose{}\right]"),
        (
            "[i for i in n if i > 0]",
            r"\mathopen{}\left[ i \mid"
            r" \mathopen{}\left( i \in n \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \mathclose{}\right]",
        ),
        (
            "[i for i in n if i > 0 if f(i)]",
            r"\mathopen{}\left[ i \mid"
            r" \mathopen{}\left( i \in n \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \land \mathopen{}\left( f \mathopen{}\left("
            r" i \mathclose{}\right) \mathclose{}\right)"
            r" \mathclose{}\right]",
        ),
        (
            "[i for k in n for i in k]",
            r"\mathopen{}\left[ i \mid k \in n, i \in k" r" \mathclose{}\right]",
        ),
        (
            "[i for k in n for i in k if i > 0]",
            r"\mathopen{}\left[ i \mid"
            r" k \in n,"
            r" \mathopen{}\left( i \in k \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \mathclose{}\right]",
        ),
        (
            "[i for k in n if f(k) for i in k if i > 0]",
            r"\mathopen{}\left[ i \mid"
            r" \mathopen{}\left( k \in n \mathclose{}\right)"
            r" \land \mathopen{}\left( f \mathopen{}\left("
            r" k \mathclose{}\right) \mathclose{}\right),"
            r" \mathopen{}\left( i \in k \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \mathclose{}\right]",
        ),
    ],
)
def test_visit_listcomp(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.ListComp)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(tree) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("{i for i in n}", r"\mathopen{}\left\{ i \mid i \in n \mathclose{}\right\}"),
        (
            "{i for i in n if i > 0}",
            r"\mathopen{}\left\{ i \mid"
            r" \mathopen{}\left( i \in n \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \mathclose{}\right\}",
        ),
        (
            "{i for i in n if i > 0 if f(i)}",
            r"\mathopen{}\left\{ i \mid"
            r" \mathopen{}\left( i \in n \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \land \mathopen{}\left( f \mathopen{}\left("
            r" i \mathclose{}\right) \mathclose{}\right)"
            r" \mathclose{}\right\}",
        ),
        (
            "{i for k in n for i in k}",
            r"\mathopen{}\left\{ i \mid k \in n, i \in k" r" \mathclose{}\right\}",
        ),
        (
            "{i for k in n for i in k if i > 0}",
            r"\mathopen{}\left\{ i \mid"
            r" k \in n,"
            r" \mathopen{}\left( i \in k \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \mathclose{}\right\}",
        ),
        (
            "{i for k in n if f(k) for i in k if i > 0}",
            r"\mathopen{}\left\{ i \mid"
            r" \mathopen{}\left( k \in n \mathclose{}\right)"
            r" \land \mathopen{}\left( f \mathopen{}\left("
            r" k \mathclose{}\right) \mathclose{}\right),"
            r" \mathopen{}\left( i \in k \mathclose{}\right)"
            r" \land \mathopen{}\left( i > 0 \mathclose{}\right)"
            r" \mathclose{}\right\}",
        ),
    ],
)
def test_visit_setcomp(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.SetComp)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(tree) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("foo(x)", r"\mathrm{foo} \mathopen{}\left( x \mathclose{}\right)"),
        ("f(x)", r"f \mathopen{}\left( x \mathclose{}\right)"),
        ("f(-x)", r"f \mathopen{}\left( -x \mathclose{}\right)"),
        ("f(x + y)", r"f \mathopen{}\left( x + y \mathclose{}\right)"),
        (
            "f(f(x))",
            r"f \mathopen{}\left("
            r" f \mathopen{}\left( x \mathclose{}\right)"
            r" \mathclose{}\right)",
        ),
        ("f(sqrt(x))", r"f \mathopen{}\left( \sqrt{ x } \mathclose{}\right)"),
        ("f(sin(x))", r"f \mathopen{}\left( \sin x \mathclose{}\right)"),
        ("f(factorial(x))", r"f \mathopen{}\left( x ! \mathclose{}\right)"),
        ("f(x, y)", r"f \mathopen{}\left( x, y \mathclose{}\right)"),
        ("sqrt(x)", r"\sqrt{ x }"),
        ("sqrt(-x)", r"\sqrt{ -x }"),
        ("sqrt(x + y)", r"\sqrt{ x + y }"),
        ("sqrt(f(x))", r"\sqrt{ f \mathopen{}\left( x \mathclose{}\right) }"),
        ("sqrt(sqrt(x))", r"\sqrt{ \sqrt{ x } }"),
        ("sqrt(sin(x))", r"\sqrt{ \sin x }"),
        ("sqrt(factorial(x))", r"\sqrt{ x ! }"),
        ("sin(x)", r"\sin x"),
        ("sin(-x)", r"\sin \mathopen{}\left( -x \mathclose{}\right)"),
        ("sin(x + y)", r"\sin \mathopen{}\left( x + y \mathclose{}\right)"),
        ("sin(f(x))", r"\sin f \mathopen{}\left( x \mathclose{}\right)"),
        ("sin(sqrt(x))", r"\sin \sqrt{ x }"),
        ("sin(sin(x))", r"\sin \sin x"),
        ("sin(factorial(x))", r"\sin \mathopen{}\left( x ! \mathclose{}\right)"),
        ("factorial(x)", r"x !"),
        ("factorial(-x)", r"\mathopen{}\left( -x \mathclose{}\right) !"),
        ("factorial(x + y)", r"\mathopen{}\left( x + y \mathclose{}\right) !"),
        (
            "factorial(f(x))",
            r"\mathopen{}\left("
            r" f \mathopen{}\left( x \mathclose{}\right)"
            r" \mathclose{}\right) !",
        ),
        ("factorial(sqrt(x))", r"\mathopen{}\left( \sqrt{ x } \mathclose{}\right) !"),
        ("factorial(sin(x))", r"\mathopen{}\left( \sin x \mathclose{}\right) !"),
        ("factorial(factorial(x))", r"\mathopen{}\left( x ! \mathclose{}\right) !"),
    ],
)
def test_visit_call(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    generated = expression_codegen.ExpressionCodegen().visit(tree)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("log(x)**2", r"\mathopen{}\left( \log x \mathclose{}\right)^{2}"),
        ("log(x**2)", r"\log \mathopen{}\left( x^{2} \mathclose{}\right)"),
        (
            "log(x**2)**3",
            r"\mathopen{}\left("
            r" \log \mathopen{}\left( x^{2} \mathclose{}\right)"
            r" \mathclose{}\right)^{3}",
        ),
    ],
)
def test_visit_call_with_pow(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    # The outermost node is a Call or, when the call is raised to a power, a BinOp.
    assert isinstance(tree, (ast.Call, ast.BinOp))
    generated = expression_codegen.ExpressionCodegen().visit(tree)
    assert generated == latex
@pytest.mark.parametrize(
    "src_suffix,dest_suffix",
    [
        # No arguments
        ("()", r" \mathopen{}\left( \mathclose{}\right)"),
        # No comprehension
        ("(x)", r" x"),
        (
            "([1, 2])",
            r" \mathopen{}\left[ 1, 2 \mathclose{}\right]",
        ),
        (
            "({1, 2})",
            r" \mathopen{}\left\{ 1, 2 \mathclose{}\right\}",
        ),
        ("(f(x))", r" f \mathopen{}\left( x \mathclose{}\right)"),
        # Single comprehension
        ("(i for i in x)", r"_{i \in x}^{} \mathopen{}\left({i}\mathclose{}\right)"),
        (
            "(i for i in [1, 2])",
            r"_{i \in \mathopen{}\left[ 1, 2 \mathclose{}\right]}^{} "
            r"\mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in {1, 2})",
            r"_{i \in \mathopen{}\left\{ 1, 2 \mathclose{}\right\}}^{}"
            r" \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in f(x))",
            r"_{i \in f \mathopen{}\left( x \mathclose{}\right)}^{}"
            r" \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n))",
            r"_{i = 0}^{n - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n + 1))",
            r"_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n + 2))",
            r"_{i = 0}^{n + 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            # ast.parse() does not recognize negative integers.
            "(i for i in range(n - -1))",
            r"_{i = 0}^{n - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n - 1))",
            r"_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n + m))",
            r"_{i = 0}^{n + m - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n - m))",
            r"_{i = 0}^{n - m - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(3))",
            r"_{i = 0}^{2} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(3 + 1))",
            r"_{i = 0}^{3} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(3 + 2))",
            r"_{i = 0}^{3 + 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(3 - 1))",
            r"_{i = 0}^{3 - 2} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            # ast.parse() does not recognize negative integers.
            "(i for i in range(3 - -1))",
            r"_{i = 0}^{3 - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(3 + m))",
            r"_{i = 0}^{3 + m - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(3 - m))",
            r"_{i = 0}^{3 - m - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n, m))",
            r"_{i = n}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(1, m))",
            r"_{i = 1}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n, 3))",
            r"_{i = n}^{2} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in range(n, m, k))",
            r"_{i \in \mathrm{range} \mathopen{}\left( n, m, k \mathclose{}\right)}^{}"
            r" \mathopen{}\left({i}\mathclose{}\right)",
        ),
    ],
)
def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None:
    # Every case is checked for each of the three big-operator functions.
    fn_pairs = (("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod"))
    for src_fn, dest_fn in fn_pairs:
        tree = ast_utils.parse_expr(f"{src_fn}{src_suffix}")
        assert isinstance(tree, ast.Call)
        generated = expression_codegen.ExpressionCodegen().visit(tree)
        assert generated == f"{dest_fn}{dest_suffix}"
@pytest.mark.parametrize(
    "code,latex",
    [
        # 2 clauses
        (
            "sum(i for y in x for i in y)",
            r"\sum_{y \in x}^{} \sum_{i \in y}^{} "
            r"\mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "prod(i for y in x for i in y)",
            r"\prod_{y \in x}^{} \prod_{i \in y}^{} "
            r"\mathopen{}\left({i}\mathclose{}\right)",
        ),
        # 3 clauses
        (
            "sum(i for y in x for z in y for i in z)",
            r"\sum_{y \in x}^{} \sum_{z \in y}^{} \sum_{i \in z}^{} "
            r"\mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "prod(i for y in x for z in y for i in z)",
            r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} "
            r"\mathopen{}\left({i}\mathclose{}\right)",
        ),
        # reduce stop parameter
        (
            "sum(i for i in range(n+1))",
            r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "prod(i for i in range(n-1))",
            r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)",
        ),
    ],
)
def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None:
    node = ast_utils.parse_expr(code)
    assert isinstance(node, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(node) == latex
@pytest.mark.parametrize(
    "src_suffix,dest_suffix",
    [
        (
            "(i for i in x if i < y)",
            r"_{\mathopen{}\left( i \in x \mathclose{}\right) "
            r"\land \mathopen{}\left( i < y \mathclose{}\right)}^{} "
            r"\mathopen{}\left({i}\mathclose{}\right)",
        ),
        (
            "(i for i in x if i < y if f(i))",
            r"_{\mathopen{}\left( i \in x \mathclose{}\right)"
            r" \land \mathopen{}\left( i < y \mathclose{}\right)"
            r" \land \mathopen{}\left( f \mathopen{}\left("
            r" i \mathclose{}\right) \mathclose{}\right)}^{}"
            r" \mathopen{}\left({i}\mathclose{}\right)",
        ),
    ],
)
def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None:
    # Same set of functions as test_visit_call_sum_prod (fsum included).
    for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]:
        node = ast_utils.parse_expr(src_fn + src_suffix)
        assert isinstance(node, ast.Call)
        assert (
            expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix
        )
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "x if x < y else y",
            r"\left\{ \begin{array}{ll}"
            r" x, & \mathrm{if} \ x < y \\"
            r" y, & \mathrm{otherwise}"
            r" \end{array} \right.",
        ),
        (
            "x if x < y else (y if y < z else z)",
            r"\left\{ \begin{array}{ll}"
            r" x, & \mathrm{if} \ x < y \\"
            r" y, & \mathrm{if} \ y < z \\"
            r" z, & \mathrm{otherwise}"
            r" \end{array} \right.",
        ),
        (
            "x if x < y else (y if y < z else (z if z < w else w))",
            r"\left\{ \begin{array}{ll}"
            r" x, & \mathrm{if} \ x < y \\"
            r" y, & \mathrm{if} \ y < z \\"
            r" z, & \mathrm{if} \ z < w \\"
            r" w, & \mathrm{otherwise}"
            r" \end{array} \right.",
        ),
    ],
)
def test_if_then_else(code: str, latex: str) -> None:
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.IfExp)
    generated = expression_codegen.ExpressionCodegen().visit(tree)
    assert generated == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        # x op y
        ("x**y", r"x^{y}"),
        ("x * y", r"x y"),
        ("x @ y", r"x y"),
        ("x / y", r"\frac{x}{y}"),
        ("x // y", r"\left\lfloor\frac{x}{y}\right\rfloor"),
        ("x % y", r"x \mathbin{\%} y"),
        ("x + y", r"x + y"),
        ("x - y", r"x - y"),
        ("x << y", r"x \ll y"),
        ("x >> y", r"x \gg y"),
        ("x & y", r"x \mathbin{\&} y"),
        ("x ^ y", r"x \oplus y"),
        # Lowercase raw-string prefix, consistent with every other entry.
        ("x | y", r"x \mathbin{|} y"),
        # (x op y) op z
        ("(x**y)**z", r"\mathopen{}\left( x^{y} \mathclose{}\right)^{z}"),
        ("(x * y) * z", r"x y z"),
        ("(x @ y) @ z", r"x y z"),
        ("(x / y) / z", r"\frac{\frac{x}{y}}{z}"),
        (
            "(x // y) // z",
            r"\left\lfloor\frac{\left\lfloor\frac{x}{y}\right\rfloor}{z}\right\rfloor",
        ),
        ("(x % y) % z", r"x \mathbin{\%} y \mathbin{\%} z"),
        ("(x + y) + z", r"x + y + z"),
        ("(x - y) - z", r"x - y - z"),
        ("(x << y) << z", r"x \ll y \ll z"),
        ("(x >> y) >> z", r"x \gg y \gg z"),
        ("(x & y) & z", r"x \mathbin{\&} y \mathbin{\&} z"),
        ("(x ^ y) ^ z", r"x \oplus y \oplus z"),
        ("(x | y) | z", r"x \mathbin{|} y \mathbin{|} z"),
        # x op (y op z)
        ("x**(y**z)", r"x^{y^{z}}"),
        ("x * (y * z)", r"x y z"),
        ("x @ (y @ z)", r"x y z"),
        ("x / (y / z)", r"\frac{x}{\frac{y}{z}}"),
        (
            "x // (y // z)",
            r"\left\lfloor\frac{x}{\left\lfloor\frac{y}{z}\right\rfloor}\right\rfloor",
        ),
        (
            "x % (y % z)",
            r"x \mathbin{\%} \mathopen{}\left( y \mathbin{\%} z \mathclose{}\right)",
        ),
        ("x + (y + z)", r"x + y + z"),
        ("x - (y - z)", r"x - \mathopen{}\left( y - z \mathclose{}\right)"),
        ("x << (y << z)", r"x \ll \mathopen{}\left( y \ll z \mathclose{}\right)"),
        ("x >> (y >> z)", r"x \gg \mathopen{}\left( y \gg z \mathclose{}\right)"),
        ("x & (y & z)", r"x \mathbin{\&} y \mathbin{\&} z"),
        ("x ^ (y ^ z)", r"x \oplus y \oplus z"),
        ("x | (y | z)", r"x \mathbin{|} y \mathbin{|} z"),
        # x OP y op z
        ("x**y * z", r"x^{y} z"),
        ("x * y + z", r"x y + z"),
        ("x @ y + z", r"x y + z"),
        ("x / y + z", r"\frac{x}{y} + z"),
        ("x // y + z", r"\left\lfloor\frac{x}{y}\right\rfloor + z"),
        ("x % y + z", r"x \mathbin{\%} y + z"),
        ("x + y << z", r"x + y \ll z"),
        ("x - y << z", r"x - y \ll z"),
        ("x << y & z", r"x \ll y \mathbin{\&} z"),
        ("x >> y & z", r"x \gg y \mathbin{\&} z"),
        ("x & y ^ z", r"x \mathbin{\&} y \oplus z"),
        ("x ^ y | z", r"x \oplus y \mathbin{|} z"),
        # x OP (y op z)
        ("x**(y * z)", r"x^{y z}"),
        ("x * (y + z)", r"x \cdot \mathopen{}\left( y + z \mathclose{}\right)"),
        ("x @ (y + z)", r"x \cdot \mathopen{}\left( y + z \mathclose{}\right)"),
        ("x / (y + z)", r"\frac{x}{y + z}"),
        ("x // (y + z)", r"\left\lfloor\frac{x}{y + z}\right\rfloor"),
        ("x % (y + z)", r"x \mathbin{\%} \mathopen{}\left( y + z \mathclose{}\right)"),
        ("x + (y << z)", r"x + \mathopen{}\left( y \ll z \mathclose{}\right)"),
        ("x - (y << z)", r"x - \mathopen{}\left( y \ll z \mathclose{}\right)"),
        (
            "x << (y & z)",
            r"x \ll \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)",
        ),
        (
            "x >> (y & z)",
            r"x \gg \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)",
        ),
        (
            "x & (y ^ z)",
            r"x \mathbin{\&} \mathopen{}\left( y \oplus z \mathclose{}\right)",
        ),
        (
            "x ^ (y | z)",
            r"x \oplus \mathopen{}\left( y \mathbin{|} z \mathclose{}\right)",
        ),
        # x op y OP z
        ("x * y**z", r"x y^{z}"),
        ("x + y * z", r"x + y z"),
        ("x + y @ z", r"x + y z"),
        ("x + y / z", r"x + \frac{y}{z}"),
        ("x + y // z", r"x + \left\lfloor\frac{y}{z}\right\rfloor"),
        ("x + y % z", r"x + y \mathbin{\%} z"),
        ("x << y + z", r"x \ll y + z"),
        ("x << y - z", r"x \ll y - z"),
        ("x & y << z", r"x \mathbin{\&} y \ll z"),
        ("x & y >> z", r"x \mathbin{\&} y \gg z"),
        ("x ^ y & z", r"x \oplus y \mathbin{\&} z"),
        ("x | y ^ z", r"x \mathbin{|} y \oplus z"),
        # (x op y) OP z
        ("(x * y)**z", r"\mathopen{}\left( x y \mathclose{}\right)^{z}"),
        ("(x + y) * z", r"\mathopen{}\left( x + y \mathclose{}\right) z"),
        ("(x + y) @ z", r"\mathopen{}\left( x + y \mathclose{}\right) z"),
        ("(x + y) / z", r"\frac{x + y}{z}"),
        ("(x + y) // z", r"\left\lfloor\frac{x + y}{z}\right\rfloor"),
        ("(x + y) % z", r"\mathopen{}\left( x + y \mathclose{}\right) \mathbin{\%} z"),
        ("(x << y) + z", r"\mathopen{}\left( x \ll y \mathclose{}\right) + z"),
        ("(x << y) - z", r"\mathopen{}\left( x \ll y \mathclose{}\right) - z"),
        (
            "(x & y) << z",
            r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \ll z",
        ),
        (
            "(x & y) >> z",
            r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \gg z",
        ),
        (
            "(x ^ y) & z",
            r"\mathopen{}\left( x \oplus y \mathclose{}\right) \mathbin{\&} z",
        ),
        (
            "(x | y) ^ z",
            r"\mathopen{}\left( x \mathbin{|} y \mathclose{}\right) \oplus z",
        ),
        # is_wrapped
        ("(x // y)**z", r"\left\lfloor\frac{x}{y}\right\rfloor^{z}"),
        # With Call
        ("x**f(y)", r"x^{f \mathopen{}\left( y \mathclose{}\right)}"),
        (
            "f(x)**y",
            r"\mathopen{}\left("
            r" f \mathopen{}\left( x \mathclose{}\right)"
            r" \mathclose{}\right)^{y}",
        ),
        ("x * f(y)", r"x \cdot f \mathopen{}\left( y \mathclose{}\right)"),
        ("f(x) * y", r"f \mathopen{}\left( x \mathclose{}\right) \cdot y"),
        ("x / f(y)", r"\frac{x}{f \mathopen{}\left( y \mathclose{}\right)}"),
        ("f(x) / y", r"\frac{f \mathopen{}\left( x \mathclose{}\right)}{y}"),
        ("x + f(y)", r"x + f \mathopen{}\left( y \mathclose{}\right)"),
        ("f(x) + y", r"f \mathopen{}\left( x \mathclose{}\right) + y"),
        # With is_wrapped Call
        ("sqrt(x) ** y", r"\sqrt{ x }^{y}"),
        # With UnaryOp
        ("x**-y", r"x^{-y}"),
        ("(-x)**y", r"\mathopen{}\left( -x \mathclose{}\right)^{y}"),
        ("x * -y", r"x \cdot -y"),
        ("-x * y", r"-x y"),
        ("x / -y", r"\frac{x}{-y}"),
        ("-x / y", r"\frac{-x}{y}"),
        ("x + -y", r"x + -y"),
        ("-x + y", r"-x + y"),
        # With Compare
        ("x**(y == z)", r"x^{y = z}"),
        ("(x == y)**z", r"\mathopen{}\left( x = y \mathclose{}\right)^{z}"),
        ("x * (y == z)", r"x \cdot \mathopen{}\left( y = z \mathclose{}\right)"),
        ("(x == y) * z", r"\mathopen{}\left( x = y \mathclose{}\right) z"),
        ("x / (y == z)", r"\frac{x}{y = z}"),
        ("(x == y) / z", r"\frac{x = y}{z}"),
        ("x + (y == z)", r"x + \mathopen{}\left( y = z \mathclose{}\right)"),
        ("(x == y) + z", r"\mathopen{}\left( x = y \mathclose{}\right) + z"),
        # With BoolOp
        ("x**(y and z)", r"x^{y \land z}"),
        ("(x and y)**z", r"\mathopen{}\left( x \land y \mathclose{}\right)^{z}"),
        ("x * (y and z)", r"x \cdot \mathopen{}\left( y \land z \mathclose{}\right)"),
        ("(x and y) * z", r"\mathopen{}\left( x \land y \mathclose{}\right) z"),
        ("x / (y and z)", r"\frac{x}{y \land z}"),
        ("(x and y) / z", r"\frac{x \land y}{z}"),
        ("x + (y and z)", r"x + \mathopen{}\left( y \land z \mathclose{}\right)"),
        ("(x and y) + z", r"\mathopen{}\left( x \land y \mathclose{}\right) + z"),
    ],
)
def test_visit_binop(code: str, latex: str) -> None:
    """Checks BinOp rendering, including parenthesization decided by precedence.

    The cases are grouped by the shape of the expression (same operator on each
    side, higher/lower-precedence operand, wrapped operands, and operands that are
    calls, unary operators, comparisons or boolean operators).
    """
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.BinOp)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        # With literals
        ("+x", r"+x"),
        ("-x", r"-x"),
        ("~x", r"\mathord{\sim} x"),
        ("not x", r"\lnot x"),
        # With Call
        ("+f(x)", r"+f \mathopen{}\left( x \mathclose{}\right)"),
        ("-f(x)", r"-f \mathopen{}\left( x \mathclose{}\right)"),
        ("~f(x)", r"\mathord{\sim} f \mathopen{}\left( x \mathclose{}\right)"),
        ("not f(x)", r"\lnot f \mathopen{}\left( x \mathclose{}\right)"),
        # With BinOp
        ("+(x + y)", r"+\mathopen{}\left( x + y \mathclose{}\right)"),
        ("-(x + y)", r"-\mathopen{}\left( x + y \mathclose{}\right)"),
        ("~(x + y)", r"\mathord{\sim} \mathopen{}\left( x + y \mathclose{}\right)"),
        ("not x + y", r"\lnot \mathopen{}\left( x + y \mathclose{}\right)"),
        # With Compare
        ("+(x == y)", r"+\mathopen{}\left( x = y \mathclose{}\right)"),
        ("-(x == y)", r"-\mathopen{}\left( x = y \mathclose{}\right)"),
        ("~(x == y)", r"\mathord{\sim} \mathopen{}\left( x = y \mathclose{}\right)"),
        ("not x == y", r"\lnot \mathopen{}\left( x = y \mathclose{}\right)"),
        # With BoolOp
        ("+(x and y)", r"+\mathopen{}\left( x \land y \mathclose{}\right)"),
        ("-(x and y)", r"-\mathopen{}\left( x \land y \mathclose{}\right)"),
        (
            "~(x and y)",
            r"\mathord{\sim} \mathopen{}\left( x \land y \mathclose{}\right)",
        ),
        ("not (x and y)", r"\lnot \mathopen{}\left( x \land y \mathclose{}\right)"),
    ],
)
def test_visit_unaryop(code: str, latex: str) -> None:
    """Checks UnaryOp rendering and when its operand gets parenthesized."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.UnaryOp)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        # 1 comparator
        ("a == b", "a = b"),
        ("a > b", "a > b"),
        ("a >= b", r"a \ge b"),
        ("a in b", r"a \in b"),
        ("a is b", r"a \equiv b"),
        ("a is not b", r"a \not\equiv b"),
        ("a < b", "a < b"),
        ("a <= b", r"a \le b"),
        ("a != b", r"a \ne b"),
        ("a not in b", r"a \notin b"),
        # 2 comparators
        ("a == b == c", "a = b = c"),
        ("a == b > c", "a = b > c"),
        ("a == b >= c", r"a = b \ge c"),
        ("a == b < c", "a = b < c"),
        ("a == b <= c", r"a = b \le c"),
        ("a > b == c", "a > b = c"),
        ("a > b > c", "a > b > c"),
        ("a > b >= c", r"a > b \ge c"),
        ("a >= b == c", r"a \ge b = c"),
        ("a >= b > c", r"a \ge b > c"),
        ("a >= b >= c", r"a \ge b \ge c"),
        ("a < b == c", "a < b = c"),
        ("a < b < c", "a < b < c"),
        ("a < b <= c", r"a < b \le c"),
        ("a <= b == c", r"a \le b = c"),
        ("a <= b < c", r"a \le b < c"),
        ("a <= b <= c", r"a \le b \le c"),
        # With Call
        ("a == f(b)", r"a = f \mathopen{}\left( b \mathclose{}\right)"),
        ("f(a) == b", r"f \mathopen{}\left( a \mathclose{}\right) = b"),
        # With BinOp
        ("a == b + c", r"a = b + c"),
        ("a + b == c", r"a + b = c"),
        # With UnaryOp
        ("a == -b", r"a = -b"),
        ("-a == b", r"-a = b"),
        ("a == (not b)", r"a = \lnot b"),
        ("(not a) == b", r"\lnot a = b"),
        # With BoolOp
        ("a == (b and c)", r"a = \mathopen{}\left( b \land c \mathclose{}\right)"),
        ("(a and b) == c", r"\mathopen{}\left( a \land b \mathclose{}\right) = c"),
    ],
)
def test_visit_compare(code: str, latex: str) -> None:
    """Checks single and chained comparisons and the wrapping of their operands."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Compare)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        # With literals
        ("a and b", r"a \land b"),
        ("a and b and c", r"a \land b \land c"),
        ("a or b", r"a \lor b"),
        ("a or b or c", r"a \lor b \lor c"),
        ("a or b and c", r"a \lor b \land c"),
        (
            "(a or b) and c",
            r"\mathopen{}\left( a \lor b \mathclose{}\right) \land c",
        ),
        ("a and b or c", r"a \land b \lor c"),
        (
            "a and (b or c)",
            r"a \land \mathopen{}\left( b \lor c \mathclose{}\right)",
        ),
        # With Call
        ("a and f(b)", r"a \land f \mathopen{}\left( b \mathclose{}\right)"),
        ("f(a) and b", r"f \mathopen{}\left( a \mathclose{}\right) \land b"),
        ("a or f(b)", r"a \lor f \mathopen{}\left( b \mathclose{}\right)"),
        ("f(a) or b", r"f \mathopen{}\left( a \mathclose{}\right) \lor b"),
        # With BinOp
        ("a and b + c", r"a \land b + c"),
        ("a + b and c", r"a + b \land c"),
        ("a or b + c", r"a \lor b + c"),
        ("a + b or c", r"a + b \lor c"),
        # With UnaryOp
        ("a and not b", r"a \land \lnot b"),
        ("not a and b", r"\lnot a \land b"),
        ("a or not b", r"a \lor \lnot b"),
        ("not a or b", r"\lnot a \lor b"),
        # With Compare
        ("a and b == c", r"a \land b = c"),
        ("a == b and c", r"a = b \land c"),
        ("a or b == c", r"a \lor b = c"),
        ("a == b or c", r"a = b \lor c"),
    ],
)
def test_visit_boolop(code: str, latex: str) -> None:
    """Checks `and`/`or` rendering and precedence-based parenthesization."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.BoolOp)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("0", "0"),
        ("1", "1"),
        ("0.0", "0.0"),
        ("1.5", "1.5"),
        ("0.0j", "0j"),
        ("1.0j", "1j"),
        ("1.5j", "1.5j"),
        ('"abc"', r'\textrm{"abc"}'),
        ('b"abc"', r"\textrm{b'abc'}"),
        ("None", r"\mathrm{None}"),
        ("False", r"\mathrm{False}"),
        ("True", r"\mathrm{True}"),
        ("...", r"\cdots"),
    ],
)
def test_visit_constant(code: str, latex: str) -> None:
    """Checks numbers, strings, bytes, singletons and Ellipsis literals."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Constant)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("x[0]", "x_{0}"),
        ("x[0][1]", "x_{0, 1}"),
        ("x[0][1][2]", "x_{0, 1, 2}"),
        ("x[foo]", r"x_{\mathrm{foo}}"),
        ("x[floor(x)]", r"x_{\mathopen{}\left\lfloor x \mathclose{}\right\rfloor}"),
    ],
)
def test_visit_subscript(code: str, latex: str) -> None:
    """Checks that (chained) subscripts collapse into one LaTeX subscript."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Subscript)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("a - b", r"a \setminus b"),
        ("a & b", r"a \cap b"),
        ("a ^ b", r"a \mathbin{\triangle} b"),
        ("a | b", r"a \cup b"),
    ],
)
def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
    """Checks that set-like BinOps use set notation when use_set_symbols=True."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.BinOp)
    codegen = expression_codegen.ExpressionCodegen(use_set_symbols=True)
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("a < b", r"a \subset b"),
        ("a <= b", r"a \subseteq b"),
        ("a > b", r"a \supset b"),
        ("a >= b", r"a \supseteq b"),
    ],
)
def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
    """Checks that ordering comparisons become subset/superset relations."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Compare)
    codegen = expression_codegen.ExpressionCodegen(use_set_symbols=True)
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("array(1)", r"\mathrm{array} \mathopen{}\left( 1 \mathclose{}\right)"),
        (
            "array([])",
            r"\mathrm{array} \mathopen{}\left("
            r" \mathopen{}\left[ \mathclose{}\right]"
            r" \mathclose{}\right)",
        ),
        ("array([1])", r"\begin{bmatrix} 1 \end{bmatrix}"),
        ("array([1, 2, 3])", r"\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}"),
        (
            "array([[]])",
            r"\mathrm{array} \mathopen{}\left("
            r" \mathopen{}\left[ \mathopen{}\left["
            r" \mathclose{}\right] \mathclose{}\right]"
            r" \mathclose{}\right)",
        ),
        ("array([[1]])", r"\begin{bmatrix} 1 \end{bmatrix}"),
        ("array([[1], [2], [3]])", r"\begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}"),
        (
            "array([[1], [2], [3, 4]])",
            r"\mathrm{array} \mathopen{}\left("
            r" \mathopen{}\left["
            r" \mathopen{}\left[ 1 \mathclose{}\right],"
            r" \mathopen{}\left[ 2 \mathclose{}\right],"
            r" \mathopen{}\left[ 3, 4 \mathclose{}\right]"
            r" \mathclose{}\right]"
            r" \mathclose{}\right)",
        ),
        (
            "array([[1, 2], [3, 4], [5, 6]])",
            r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}",
        ),
        # Only checks two cases for ndarray.
        ("ndarray(1)", r"\mathrm{ndarray} \mathopen{}\left( 1 \mathclose{}\right)"),
        ("ndarray([1])", r"\begin{bmatrix} 1 \end{bmatrix}"),
    ],
)
def test_numpy_array(code: str, latex: str) -> None:
    """Checks bmatrix output for rectangular literals and fallback otherwise.

    Empty or ragged nested lists, and non-list arguments, are expected to be
    rendered as an ordinary function call.
    """
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("zeros(0)", r"\mathbf{0}^{1 \times 0}"),
        ("zeros(1)", r"\mathbf{0}^{1 \times 1}"),
        ("zeros(2)", r"\mathbf{0}^{1 \times 2}"),
        ("zeros(())", r"0"),
        ("zeros((0,))", r"\mathbf{0}^{1 \times 0}"),
        ("zeros((1,))", r"\mathbf{0}^{1 \times 1}"),
        ("zeros((2,))", r"\mathbf{0}^{1 \times 2}"),
        ("zeros((0, 0))", r"\mathbf{0}^{0 \times 0}"),
        ("zeros((1, 1))", r"\mathbf{0}^{1 \times 1}"),
        ("zeros((2, 3))", r"\mathbf{0}^{2 \times 3}"),
        ("zeros((0, 0, 0))", r"\mathbf{0}^{0 \times 0 \times 0}"),
        ("zeros((1, 1, 1))", r"\mathbf{0}^{1 \times 1 \times 1}"),
        ("zeros((2, 3, 5))", r"\mathbf{0}^{2 \times 3 \times 5}"),
        # Unsupported
        ("zeros()", r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)"),
        ("zeros(x)", r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)"),
        ("zeros(0, x)", r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)"),
        (
            "zeros((x,))",
            r"\mathrm{zeros} \mathopen{}\left("
            r" \mathopen{}\left( x \mathclose{}\right)"
            r" \mathclose{}\right)",
        ),
    ],
)
def test_zeros(code: str, latex: str) -> None:
    """Checks `zeros` with literal shapes; other argument forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("identity(0)", r"\mathbf{I}_{0}"),
        ("identity(1)", r"\mathbf{I}_{1}"),
        ("identity(2)", r"\mathbf{I}_{2}"),
        # Unsupported
        ("identity()", r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)"),
        ("identity(x)", r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)"),
        (
            "identity(0, x)",
            r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)",
        ),
    ],
)
def test_identity(code: str, latex: str) -> None:
    """Checks `identity(n)` with a literal size; other forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("transpose(A)", r"\mathbf{A}^\intercal"),
        ("transpose(b)", r"\mathbf{b}^\intercal"),
        # Unsupported
        ("transpose()", r"\mathrm{transpose} \mathopen{}\left( \mathclose{}\right)"),
        ("transpose(2)", r"\mathrm{transpose} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "transpose(a, (1, 0))",
            r"\mathrm{transpose} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_transpose(code: str, latex: str) -> None:
    """Checks `transpose` of a single name; other forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("det(A)", r"\det \mathopen{}\left( \mathbf{A} \mathclose{}\right)"),
        ("det(b)", r"\det \mathopen{}\left( \mathbf{b} \mathclose{}\right)"),
        (
            "det([[1, 2], [3, 4]])",
            r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
            r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
        ),
        (
            "det([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
            r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
        ),
        # Unsupported
        ("det()", r"\mathrm{det} \mathopen{}\left( \mathclose{}\right)"),
        ("det(2)", r"\mathrm{det} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "det(a, (1, 0))",
            r"\mathrm{det} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_determinant(code: str, latex: str) -> None:
    """Checks `det` of a name or matrix literal; other forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "matrix_rank(A)",
            r"\mathrm{rank} \mathopen{}\left( \mathbf{A} \mathclose{}\right)",
        ),
        (
            "matrix_rank(b)",
            r"\mathrm{rank} \mathopen{}\left( \mathbf{b} \mathclose{}\right)",
        ),
        (
            "matrix_rank([[1, 2], [3, 4]])",
            r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
            r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
        ),
        (
            "matrix_rank([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
            r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
        ),
        # Unsupported
        (
            "matrix_rank()",
            r"\mathrm{matrix\_rank} \mathopen{}\left( \mathclose{}\right)",
        ),
        (
            "matrix_rank(2)",
            r"\mathrm{matrix\_rank} \mathopen{}\left( 2 \mathclose{}\right)",
        ),
        (
            "matrix_rank(a, (1, 0))",
            r"\mathrm{matrix\_rank} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_matrix_rank(code: str, latex: str) -> None:
    """Checks `matrix_rank` rendering; other argument forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("matrix_power(A, 2)", r"\mathbf{A}^{2}"),
        ("matrix_power(b, 2)", r"\mathbf{b}^{2}"),
        (
            "matrix_power([[1, 2], [3, 4]], 2)",
            r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{2}",
        ),
        (
            "matrix_power([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 42)",
            r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{42}",
        ),
        # Unsupported
        (
            "matrix_power()",
            r"\mathrm{matrix\_power} \mathopen{}\left( \mathclose{}\right)",
        ),
        (
            "matrix_power(2)",
            r"\mathrm{matrix\_power} \mathopen{}\left( 2 \mathclose{}\right)",
        ),
        (
            "matrix_power(a, (1, 0))",
            r"\mathrm{matrix\_power} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_matrix_power(code: str, latex: str) -> None:
    """Checks `matrix_power` as a superscript; other forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("inv(A)", r"\mathbf{A}^{-1}"),
        ("inv(b)", r"\mathbf{b}^{-1}"),
        ("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"),
        (
            "inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}",
        ),
        # Unsupported
        ("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"),
        ("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "inv(a, (1, 0))",
            r"\mathrm{inv} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_inv(code: str, latex: str) -> None:
    """Checks `inv` as a `^{-1}` superscript; other forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
@pytest.mark.parametrize(
    "code,latex",
    [
        ("pinv(A)", r"\mathbf{A}^{+}"),
        ("pinv(b)", r"\mathbf{b}^{+}"),
        ("pinv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{+}"),
        (
            "pinv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{+}",
        ),
        # Unsupported
        ("pinv()", r"\mathrm{pinv} \mathopen{}\left( \mathclose{}\right)"),
        ("pinv(2)", r"\mathrm{pinv} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "pinv(a, (1, 0))",
            r"\mathrm{pinv} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_pinv(code: str, latex: str) -> None:
    """Checks `pinv` as a `^{+}` superscript; other forms become a plain call."""
    parsed = ast_utils.parse_expr(code)
    assert isinstance(parsed, ast.Call)
    codegen = expression_codegen.ExpressionCodegen()
    assert codegen.visit(parsed) == latex
# Cases from the checklist in issue #89:
# https://github.com/google/latexify_py/issues/89#issuecomment-1344967636
@pytest.mark.parametrize(
    "left,right,latex",
    [
        ("2", "3", r"2 \cdot 3"),
        ("2", "y", "2 y"),
        ("2", "beta", r"2 \beta"),
        ("2", "bar", r"2 \mathrm{bar}"),
        ("2", "g(y)", r"2 g \mathopen{}\left( y \mathclose{}\right)"),
        ("2", "(u + v)", r"2 \mathopen{}\left( u + v \mathclose{}\right)"),
        ("x", "3", r"x \cdot 3"),
        ("x", "y", "x y"),
        ("x", "beta", r"x \beta"),
        ("x", "bar", r"x \cdot \mathrm{bar}"),
        ("x", "g(y)", r"x \cdot g \mathopen{}\left( y \mathclose{}\right)"),
        ("x", "(u + v)", r"x \cdot \mathopen{}\left( u + v \mathclose{}\right)"),
        ("alpha", "3", r"\alpha \cdot 3"),
        ("alpha", "y", r"\alpha y"),
        ("alpha", "beta", r"\alpha \beta"),
        ("alpha", "bar", r"\alpha \cdot \mathrm{bar}"),
        ("alpha", "g(y)", r"\alpha \cdot g \mathopen{}\left( y \mathclose{}\right)"),
        (
            "alpha",
            "(u + v)",
            r"\alpha \cdot \mathopen{}\left( u + v \mathclose{}\right)",
        ),
        ("foo", "3", r"\mathrm{foo} \cdot 3"),
        ("foo", "y", r"\mathrm{foo} \cdot y"),
        ("foo", "beta", r"\mathrm{foo} \cdot \beta"),
        ("foo", "bar", r"\mathrm{foo} \cdot \mathrm{bar}"),
        (
            "foo",
            "g(y)",
            r"\mathrm{foo} \cdot g \mathopen{}\left( y \mathclose{}\right)",
        ),
        (
            "foo",
            "(u + v)",
            r"\mathrm{foo} \cdot \mathopen{}\left( u + v \mathclose{}\right)",
        ),
        ("f(x)", "3", r"f \mathopen{}\left( x \mathclose{}\right) \cdot 3"),
        ("f(x)", "y", r"f \mathopen{}\left( x \mathclose{}\right) \cdot y"),
        ("f(x)", "beta", r"f \mathopen{}\left( x \mathclose{}\right) \cdot \beta"),
        (
            "f(x)",
            "bar",
            r"f \mathopen{}\left( x \mathclose{}\right) \cdot \mathrm{bar}",
        ),
        (
            "f(x)",
            "g(y)",
            r"f \mathopen{}\left( x \mathclose{}\right)"
            r" \cdot g \mathopen{}\left( y \mathclose{}\right)",
        ),
        (
            "f(x)",
            "(u + v)",
            r"f \mathopen{}\left( x \mathclose{}\right)"
            r" \cdot \mathopen{}\left( u + v \mathclose{}\right)",
        ),
        ("(s + t)", "3", r"\mathopen{}\left( s + t \mathclose{}\right) \cdot 3"),
        ("(s + t)", "y", r"\mathopen{}\left( s + t \mathclose{}\right) y"),
        ("(s + t)", "beta", r"\mathopen{}\left( s + t \mathclose{}\right) \beta"),
        (
            "(s + t)",
            "bar",
            r"\mathopen{}\left( s + t \mathclose{}\right) \mathrm{bar}",
        ),
        (
            "(s + t)",
            "g(y)",
            r"\mathopen{}\left( s + t \mathclose{}\right)"
            r" g \mathopen{}\left( y \mathclose{}\right)",
        ),
        (
            "(s + t)",
            "(u + v)",
            r"\mathopen{}\left( s + t \mathclose{}\right)"
            r" \mathopen{}\left( u + v \mathclose{}\right)",
        ),
    ],
)
def test_remove_multiply(left: str, right: str, latex: str) -> None:
    """Checks when `*` and `@` are typeset as juxtaposition versus an explicit cdot.

    Both operators are expected to produce identical output for every operand pair.
    """
    for op_token in ("*", "@"):
        parsed = ast_utils.parse_expr(f"{left} {op_token} {right}")
        assert isinstance(parsed, ast.BinOp)
        codegen = expression_codegen.ExpressionCodegen(use_math_symbols=True)
        assert codegen.visit(parsed) == latex
================================================
FILE: src/latexify/codegen/expression_rules.py
================================================
"""Codegen rules for single expressions."""
from __future__ import annotations
import ast
import dataclasses
# Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes.
# A larger value binds more tightly (e.g. `**` is highest, `or` is lowest).
# Note that these values affect only where parentheses appear around each expression
# in the generated LaTeX; they do not affect the AST itself.
# See also:
# https://docs.python.org/3/reference/expressions.html#operator-precedence
_PRECEDENCES: dict[type[ast.AST], int] = {
    ast.Pow: 120,
    ast.UAdd: 110,
    ast.USub: 110,
    ast.Invert: 110,
    ast.Mult: 100,
    ast.MatMult: 100,
    ast.Div: 100,
    ast.FloorDiv: 100,
    ast.Mod: 100,
    ast.Add: 90,
    ast.Sub: 90,
    ast.LShift: 80,
    ast.RShift: 80,
    ast.BitAnd: 70,
    ast.BitXor: 60,
    ast.BitOr: 50,
    # All comparison operators share one precedence.
    ast.In: 40,
    ast.NotIn: 40,
    ast.Is: 40,
    ast.IsNot: 40,
    ast.Lt: 40,
    ast.LtE: 40,
    ast.Gt: 40,
    ast.GtE: 40,
    ast.NotEq: 40,
    ast.Eq: 40,
    # NOTE(odashi):
    # We assume that the `not` operator has the same precedence as the other unary
    # operators `+`, `-` and `~`, because its LaTeX counterpart $\lnot$ appears to
    # bind tightly. The commented-out value below (between comparisons and `and`)
    # would instead follow Python's own precedence.
    # ast.Not: 30,
    ast.Not: 110,
    ast.And: 20,
    ast.Or: 10,
}
# NOTE(odashi):
# Function invocation is treated as a unary operator with a precedence one step
# higher than `+`, `-` and `~`. This ensures that an argument with a unary operator
# is wrapped:
#   exp(x) --> \exp x
#   exp(-x) --> \exp (-x)
#   -exp(x) --> - \exp x
_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1
# Precedence reported for nodes that carry no operator; chosen to exceed every
# value above so such nodes never need parentheses on precedence grounds.
_INF_PRECEDENCE = 1_000_000
def get_precedence(node: ast.AST) -> int:
    """Returns the precedence of the operator at the root of a subtree.

    Args:
        node: Subtree to investigate.

    Returns:
        `_CALL_PRECEDENCE` if `node` is a Call; the precedence of its operator if
        `node` is a BinOp, UnaryOp, BoolOp or Compare; otherwise `_INF_PRECEDENCE`,
        a number larger than every operator precedence.
    """
    if isinstance(node, ast.Call):
        return _CALL_PRECEDENCE

    if isinstance(node, ast.Compare):
        # Every comparison operator has the same precedence, so inspecting the first
        # one is enough.
        root_op: ast.AST = node.ops[0]
    elif isinstance(node, (ast.BinOp, ast.UnaryOp, ast.BoolOp)):
        root_op = node.op
    else:
        return _INF_PRECEDENCE

    return _PRECEDENCES[type(root_op)]
@dataclasses.dataclass(frozen=True)
class BinOperandRule:
    """Syntax rules applied to one operand of a BinOp.

    Attributes:
        wrap: Whether the operand may be wrapped in parentheses according to its
            precedence relative to the operator.
        force: Whether the operand is also wrapped when it has the same precedence
            as the operator. Used to typeset non-associative operators correctly.
    """

    wrap: bool = True
    force: bool = False
@dataclasses.dataclass(frozen=True)
class BinOpRule:
    """Syntax rules for BinOp.

    Attributes:
        latex_left: LaTeX emitted before the left operand.
        latex_middle: LaTeX emitted between the two operands.
        latex_right: LaTeX emitted after the right operand.
        operand_left: Parenthesization rule for the left operand.
        operand_right: Parenthesization rule for the right operand.
        is_wrapped: Whether the resulting syntax is assumed to be enclosed by some
            bracket operators. If True, the parent operator can avoid wrapping this
            operator in parentheses.
    """

    latex_left: str
    latex_middle: str
    latex_right: str
    # default_factory gives every rule its own default operand rule instance.
    operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)
    operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)
    is_wrapped: bool = False
# Typesetting rule for each BinOp operator. `force=True` on an operand marks the
# operator as non-associative on that side (e.g. `x - (y - z)` keeps its parentheses).
BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
    ast.Pow: BinOpRule(
        latex_left="",
        latex_middle="^{",
        latex_right="}",
        operand_left=BinOperandRule(force=True),
        operand_right=BinOperandRule(wrap=False),
    ),
    ast.Mult: BinOpRule(latex_left="", latex_middle=r" \cdot ", latex_right=""),
    ast.MatMult: BinOpRule(latex_left="", latex_middle=r" \cdot ", latex_right=""),
    ast.Div: BinOpRule(
        latex_left=r"\frac{",
        latex_middle="}{",
        latex_right="}",
        operand_left=BinOperandRule(wrap=False),
        operand_right=BinOperandRule(wrap=False),
    ),
    ast.FloorDiv: BinOpRule(
        latex_left=r"\left\lfloor\frac{",
        latex_middle="}{",
        latex_right=r"}\right\rfloor",
        operand_left=BinOperandRule(wrap=False),
        operand_right=BinOperandRule(wrap=False),
        is_wrapped=True,
    ),
    ast.Mod: BinOpRule(
        latex_left="",
        latex_middle=r" \mathbin{\%} ",
        latex_right="",
        operand_right=BinOperandRule(force=True),
    ),
    ast.Add: BinOpRule(latex_left="", latex_middle=" + ", latex_right=""),
    ast.Sub: BinOpRule(
        latex_left="",
        latex_middle=" - ",
        latex_right="",
        operand_right=BinOperandRule(force=True),
    ),
    ast.LShift: BinOpRule(
        latex_left="",
        latex_middle=r" \ll ",
        latex_right="",
        operand_right=BinOperandRule(force=True),
    ),
    ast.RShift: BinOpRule(
        latex_left="",
        latex_middle=r" \gg ",
        latex_right="",
        operand_right=BinOperandRule(force=True),
    ),
    ast.BitAnd: BinOpRule(
        latex_left="", latex_middle=r" \mathbin{\&} ", latex_right=""
    ),
    ast.BitXor: BinOpRule(latex_left="", latex_middle=r" \oplus ", latex_right=""),
    ast.BitOr: BinOpRule(
        latex_left="", latex_middle=r" \mathbin{|} ", latex_right=""
    ),
}
# Typeset for BinOp of sets: overrides the set-like operators and keeps every other
# rule from BIN_OP_RULES.
SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
    **BIN_OP_RULES,
    ast.Sub: BinOpRule(
        latex_left="",
        latex_middle=r" \setminus ",
        latex_right="",
        operand_right=BinOperandRule(force=True),
    ),
    ast.BitAnd: BinOpRule(latex_left="", latex_middle=r" \cap ", latex_right=""),
    ast.BitXor: BinOpRule(
        latex_left="", latex_middle=r" \mathbin{\triangle} ", latex_right=""
    ),
    ast.BitOr: BinOpRule(latex_left="", latex_middle=r" \cup ", latex_right=""),
}
# LaTeX prefix emitted before the operand of each UnaryOp operator. The command
# forms include a trailing space so that the output separates them from the operand
# (required for `\lnot`, whose name would otherwise merge with a following letter).
UNARY_OPS: dict[type[ast.unaryop], str] = {
    ast.Invert: r"\mathord{\sim} ",
    ast.UAdd: "+",  # Explicitly adds the $+$ operator.
    ast.USub: "-",
    ast.Not: r"\lnot ",
}
# LaTeX symbol for each comparison operator of a Compare node. Identity tests
# (`is` / `is not`) are typeset as `\equiv` / `\not\equiv`.
COMPARE_OPS: dict[type[ast.cmpop], str] = {
    ast.Eq: "=",
    ast.Gt: ">",
    ast.GtE: r"\ge",
    ast.In: r"\in",
    ast.Is: r"\equiv",
    ast.IsNot: r"\not\equiv",
    ast.Lt: "<",
    ast.LtE: r"\le",
    ast.NotEq: r"\ne",
    ast.NotIn: r"\notin",
}
# Typeset for Compare of sets: ordering operators become subset/superset relations,
# and every other operator keeps its COMPARE_OPS symbol. (Presumably selected when
# the codegen runs with use_set_symbols=True — confirm in expression_codegen.)
SET_COMPARE_OPS: dict[type[ast.cmpop], str] = {
    **COMPARE_OPS,
    ast.Gt: r"\supset",
    ast.GtE: r"\supseteq",
    ast.Lt: r"\subset",
    ast.LtE: r"\subseteq",
}
# LaTeX symbol for each BoolOp operator.
BOOL_OPS: dict[type[ast.boolop], str] = {
    ast.And: r"\land",
    ast.Or: r"\lor",
}
@dataclasses.dataclass(frozen=True)
class FunctionRule:
    """Codegen rule describing how a call to a known function is typeset.

    Attributes:
        left: LaTeX concatenated to the left-hand side of the arguments.
        right: LaTeX concatenated to the right-hand side of the arguments; empty
            by default.
        is_unary: Whether the function is treated as a unary operator.
        is_wrapped: Whether the resulting syntax is enclosed by brackets. If True,
            a parent operator can avoid wrapping the call in parentheses.
    """

    # `left` is the only required field, so it must stay first.
    left: str
    right: str = ""
    is_unary: bool = False
    is_wrapped: bool = False
# Function name => FunctionRule (LaTeX on the left/right of the arguments, plus the
# is_unary and is_wrapped flags).
BUILTIN_FUNCS: dict[str, FunctionRule] = {
    "abs": FunctionRule(
        left=r"\mathopen{}\left|", right=r"\mathclose{}\right|", is_wrapped=True
    ),
    "acos": FunctionRule(left=r"\arccos", is_unary=True),
    "acosh": FunctionRule(left=r"\mathrm{arcosh}", is_unary=True),
    "arccos": FunctionRule(left=r"\arccos", is_unary=True),
    "arccot": FunctionRule(left=r"\mathrm{arccot}", is_unary=True),
    "arccsc": FunctionRule(left=r"\mathrm{arccsc}", is_unary=True),
    "arcosh": FunctionRule(left=r"\mathrm{arcosh}", is_unary=True),
    "arcoth": FunctionRule(left=r"\mathrm{arcoth}", is_unary=True),
    "arcsec": FunctionRule(left=r"\mathrm{arcsec}", is_unary=True),
    "arcsch": FunctionRule(left=r"\mathrm{arcsch}", is_unary=True),
    "arcsin": FunctionRule(left=r"\arcsin", is_unary=True),
    "arctan": FunctionRule(left=r"\arctan", is_unary=True),
    "arsech": FunctionRule(left=r"\mathrm{arsech}", is_unary=True),
    "arsinh": FunctionRule(left=r"\mathrm{arsinh}", is_unary=True),
    "artanh": FunctionRule(left=r"\mathrm{artanh}", is_unary=True),
    "asin": FunctionRule(left=r"\arcsin", is_unary=True),
    "asinh": FunctionRule(left=r"\mathrm{arsinh}", is_unary=True),
    "atan": FunctionRule(left=r"\arctan", is_unary=True),
    "atanh": FunctionRule(left=r"\mathrm{artanh}", is_unary=True),
    "ceil": FunctionRule(
        left=r"\mathopen{}\left\lceil",
        right=r"\mathclose{}\right\rceil",
        is_wrapped=True,
    ),
    "cos": FunctionRule(left=r"\cos", is_unary=True),
    "cosh": FunctionRule(left=r"\cosh", is_unary=True),
    "cot": FunctionRule(left=r"\cot", is_unary=True),
    "coth": FunctionRule(left=r"\coth", is_unary=True),
    "csc": FunctionRule(left=r"\csc", is_unary=True),
    "csch": FunctionRule(left=r"\mathrm{csch}", is_unary=True),
    "exp": FunctionRule(left=r"\exp", is_unary=True),
    "fabs": FunctionRule(
        left=r"\mathopen{}\left|", right=r"\mathclose{}\right|", is_wrapped=True
    ),
    "factorial": FunctionRule(left="", right="!", is_unary=True),
    "floor": FunctionRule(
        left=r"\mathopen{}\left\lfloor",
        right=r"\mathclose{}\right\rfloor",
        is_wrapped=True,
    ),
    "fsum": FunctionRule(left=r"\sum", is_unary=True),
    "gamma": FunctionRule(left=r"\Gamma"),
    "log": FunctionRule(left=r"\log", is_unary=True),
    "log10": FunctionRule(left=r"\log_{10}", is_unary=True),
    "log2": FunctionRule(left=r"\log_2", is_unary=True),
    "prod": FunctionRule(left=r"\prod", is_unary=True),
    "sec": FunctionRule(left=r"\sec", is_unary=True),
    "sech": FunctionRule(left=r"\mathrm{sech}", is_unary=True),
    "sin": FunctionRule(left=r"\sin", is_unary=True),
    "sinh": FunctionRule(left=r"\sinh", is_unary=True),
    "sqrt": FunctionRule(left=r"\sqrt{", right="}", is_wrapped=True),
    "sum": FunctionRule(left=r"\sum", is_unary=True),
    "tan": FunctionRule(left=r"\tan", is_unary=True),
    "tanh": FunctionRule(left=r"\tanh", is_unary=True),
}
# Identifier names that have a LaTeX command of the same name (e.g. `alpha` ->
# `\alpha`): Greek letters and variants, Hebrew letters (aleph, beth, daleth, gimel),
# and a few symbols (hbar, infty, nabla). Presumably consulted when
# `use_math_symbols=True` — verify against identifier_converter.
# Note: `gamma` is also a key of BUILTIN_FUNCS, where calls are rendered as `\Gamma`.
MATH_SYMBOLS = {
    "aleph",
    "alpha",
    "beta",
    "beth",
    "chi",
    "daleth",
    "delta",
    "digamma",
    "epsilon",
    "eta",
    "gamma",
    "gimel",
    "hbar",
    "infty",
    "iota",
    "kappa",
    "lambda",
    "mu",
    "nabla",
    "nu",
    "omega",
    "phi",
    "pi",
    "psi",
    "rho",
    "sigma",
    "tau",
    "theta",
    "upsilon",
    "varepsilon",
    "varkappa",
    "varphi",
    "varpi",
    "varrho",
    "varsigma",
    "vartheta",
    "xi",
    "zeta",
    # Uppercase Greek letters that have their own LaTeX command.
    "Delta",
    "Gamma",
    "Lambda",
    "Omega",
    "Phi",
    "Pi",
    "Psi",
    "Sigma",
    "Theta",
    "Upsilon",
    "Xi",
}
================================================
FILE: src/latexify/codegen/expression_rules_test.py
================================================
"""Tests for latexify.codegen.expression_rules."""
from __future__ import annotations
import ast
import pytest
from latexify.codegen import expression_rules
def _load_name(identifier: str) -> ast.Name:
    """Returns an `ast.Name` node that reads `identifier`."""
    return ast.Name(id=identifier, ctx=ast.Load())


@pytest.mark.parametrize(
    "node,precedence",
    [
        (
            ast.Call(func=_load_name("func"), args=[], keywords=[]),
            expression_rules._CALL_PRECEDENCE,
        ),
        (
            ast.BinOp(
                left=_load_name("left"), op=ast.Add(), right=_load_name("right")
            ),
            expression_rules._PRECEDENCES[ast.Add],
        ),
        (
            ast.UnaryOp(op=ast.UAdd(), operand=_load_name("operand")),
            expression_rules._PRECEDENCES[ast.UAdd],
        ),
        (
            ast.BoolOp(op=ast.And(), values=[_load_name("value")]),
            expression_rules._PRECEDENCES[ast.And],
        ),
        (
            ast.Compare(
                left=_load_name("left"),
                ops=[ast.Eq()],
                comparators=[_load_name("right")],
            ),
            expression_rules._PRECEDENCES[ast.Eq],
        ),
        (_load_name("name"), expression_rules._INF_PRECEDENCE),
        (
            ast.Attribute(value=_load_name("value"), attr="attr", ctx=ast.Load()),
            expression_rules._INF_PRECEDENCE,
        ),
    ],
)
def test_get_precedence(node: ast.AST, precedence: int) -> None:
    """Checks the precedence assigned to each kind of expression node."""
    actual = expression_rules.get_precedence(node)
    assert actual == precedence
================================================
FILE: src/latexify/codegen/function_codegen.py
================================================
"""Codegen for single functions."""
from __future__ import annotations
import ast
import sys
from latexify import ast_utils, exceptions
from latexify.codegen import codegen_utils, expression_codegen, identifier_converter
class FunctionCodegen(ast.NodeVisitor):
    """Codegen for single functions.

    This codegen works for Module with single FunctionDef node to generate a single
    LaTeX expression of the given function.
    """

    _expression_codegen: expression_codegen.ExpressionCodegen
    _identifier_converter: identifier_converter.IdentifierConverter
    _use_signature: bool

    def __init__(
        self,
        *,
        use_math_symbols: bool = False,
        use_signature: bool = True,
        use_set_symbols: bool = False,
        escape_underscores: bool = True,
    ) -> None:
        """Initializer.

        Args:
            use_math_symbols: Whether to convert identifiers with a math symbol surface
                (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
            use_signature: Whether to add the function signature before the expression
                or not.
            use_set_symbols: Whether to use set symbols or not.
            escape_underscores: Whether to escape underscores in identifiers. If False,
                underscores are rendered as LaTeX subscripts instead.
        """
        self._expression_codegen = expression_codegen.ExpressionCodegen(
            use_math_symbols=use_math_symbols,
            use_set_symbols=use_set_symbols,
            escape_underscores=escape_underscores,
        )
        self._identifier_converter = identifier_converter.IdentifierConverter(
            use_math_symbols=use_math_symbols, escape_underscores=escape_underscores
        )
        self._use_signature = use_signature

    def generic_visit(self, node: ast.AST) -> str:
        """Rejects any node that has no dedicated visitor.

        Raises:
            LatexifyNotSupportedError: Always.
        """
        raise exceptions.LatexifyNotSupportedError(
            f"Unsupported AST: {type(node).__name__}"
        )

    def visit_Module(self, node: ast.Module) -> str:
        """Visit a Module node; only its first statement is converted."""
        return self.visit(node.body[0])

    def visit_FunctionDef(self, node: ast.FunctionDef) -> str:
        """Visit a FunctionDef node.

        Constant expression statements (e.g. docstrings) are skipped. Every other
        statement except the last must be an Assign. The last must be a Return, an
        If, or (Python >= 3.10) a Match.

        Raises:
            LatexifyNotSupportedError: A non-final statement is not an Assign.
            LatexifySyntaxError: The final statement has an unsupported type.
        """
        # Function name
        name_str = self._identifier_converter.convert(node.name)[0]

        # Arguments
        arg_strs = [
            self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args
        ]

        body_strs: list[str] = []

        # Assignment statements (if any): x = ...
        for child in node.body[:-1]:
            if isinstance(child, ast.Expr) and ast_utils.is_constant(child.value):
                continue

            if not isinstance(child, ast.Assign):
                raise exceptions.LatexifyNotSupportedError(
                    "Codegen supports only Assign nodes in multiline functions, "
                    f"but got: {type(child).__name__}"
                )
            body_strs.append(self.visit(child))

        return_stmt = node.body[-1]

        supported_last_stmts: tuple[type[ast.stmt], ...] = (ast.Return, ast.If)
        # Compare the whole version tuple; checking `.minor` alone ignores the major
        # version.
        if sys.version_info >= (3, 10):
            supported_last_stmts += (ast.Match,)

        if not isinstance(return_stmt, supported_last_stmts):
            raise exceptions.LatexifySyntaxError(
                f"Unsupported last statement: {type(return_stmt).__name__}"
            )

        # Function signature: f(x, ...)
        signature_str = name_str + "(" + ", ".join(arg_strs) + ")"

        # Function definition: f(x, ...) = ...
        return_str = self.visit(return_stmt)
        if self._use_signature:
            return_str = signature_str + " = " + return_str

        if not body_strs:
            # Only the definition.
            return return_str

        # Definition with several assignments. Wrap all statements with array.
        body_strs.append(return_str)
        return r"\begin{array}{l} " + r" \\ ".join(body_strs) + r" \end{array}"

    def visit_Assign(self, node: ast.Assign) -> str:
        """Visit an Assign node; chained targets are joined with " = "."""
        operands: list[str] = [self._expression_codegen.visit(t) for t in node.targets]
        operands.append(self._expression_codegen.visit(node.value))
        return " = ".join(operands)

    def visit_Return(self, node: ast.Return) -> str:
        """Visit a Return node; a bare `return` is rendered as the None constant."""
        return (
            self._expression_codegen.visit(node.value)
            if node.value is not None
            else codegen_utils.convert_constant(None)
        )

    def visit_If(self, node: ast.If) -> str:
        """Visit an If node.

        The if/elif chain must end with an else branch, and every branch must contain
        exactly one statement.

        Raises:
            LatexifySyntaxError: A branch has zero or several statements.
        """
        latex = r"\left\{ \begin{array}{ll} "

        current_stmt: ast.stmt = node

        while isinstance(current_stmt, ast.If):
            if len(current_stmt.body) != 1 or len(current_stmt.orelse) != 1:
                raise exceptions.LatexifySyntaxError(
                    "Multiple statements are not supported in If nodes."
                )

            cond_latex = self._expression_codegen.visit(current_stmt.test)
            true_latex = self.visit(current_stmt.body[0])
            latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ "
            current_stmt = current_stmt.orelse[0]

        latex += self.visit(current_stmt)
        return latex + r", & \mathrm{otherwise} \end{array} \right."

    def visit_Match(self, node: ast.Match) -> str:
        """Visit a Match node.

        The statement needs at least two cases, and the last one must be the bare
        wildcard `case _:`. Each case must consist of exactly one return statement.

        Raises:
            LatexifySyntaxError: The trailing wildcard case is missing.
            LatexifyNotSupportedError: A case body is not a single return, or a case
                has a guard.
        """
        if not (
            len(node.cases) >= 2
            and isinstance(node.cases[-1].pattern, ast.MatchAs)
            and node.cases[-1].pattern.name is None
        ):
            raise exceptions.LatexifySyntaxError(
                "Match statement must contain the wildcard."
            )

        subject_latex = self._expression_codegen.visit(node.subject)
        case_latexes: list[str] = []
        last_index = len(node.cases) - 1

        for i, case in enumerate(node.cases):
            if len(case.body) != 1 or not isinstance(case.body[0], ast.Return):
                raise exceptions.LatexifyNotSupportedError(
                    "Match cases must contain exactly 1 return statement."
                )
            # A guard (`case p if cond:`) would otherwise be dropped silently, and
            # the generated LaTeX would no longer describe the function.
            if case.guard is not None:
                raise exceptions.LatexifyNotSupportedError(
                    "Match cases with guards are not supported."
                )

            body_latex = self.visit(case.body[0])
            if i < last_index:
                cond_latex = self.visit(case.pattern)
                case_latexes.append(
                    body_latex + r", & \mathrm{if} \ " + subject_latex + cond_latex
                )
            else:
                case_latexes.append(body_latex + r", & \mathrm{otherwise}")

        return (
            r"\left\{ \begin{array}{ll} "
            + r" \\ ".join(case_latexes)
            + r" \end{array} \right."
        )

    def visit_MatchValue(self, node: ast.MatchValue) -> str:
        """Visit a MatchValue node; returns the " = <value>" condition suffix."""
        latex = self._expression_codegen.visit(node.value)
        return " = " + latex
================================================
FILE: src/latexify/codegen/function_codegen_match_test.py
================================================
"""Tests for FunctionCodegen with match statements."""
from __future__ import annotations
import ast
import textwrap
import pytest
from latexify import exceptions, test_utils
from latexify.codegen import function_codegen
@test_utils.require_at_least(10)
def test_functiondef_match() -> None:
    """A function whose body is a match statement renders as a cases array."""
    source = """
        def f(x):
            match x:
                case 0:
                    return 1
                case _:
                    return 3 * x
        """
    tree = ast.parse(textwrap.dedent(source))
    expected = " ".join(
        [
            r"f(x) =",
            r"\left\{ \begin{array}{ll}",
            r"1, & \mathrm{if} \ x = 0 \\",
            r"3 x, & \mathrm{otherwise}",
            r"\end{array} \right.",
        ]
    )
    codegen = function_codegen.FunctionCodegen()
    assert codegen.visit(tree) == expected
@test_utils.require_at_least(10)
def test_matchvalue() -> None:
    """A bare match with one value case and a wildcard renders two rows."""
    source = """
        match x:
            case 0:
                return 1
            case _:
                return 2
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    expected = " ".join(
        [
            r"\left\{ \begin{array}{ll}",
            r"1, & \mathrm{if} \ x = 0 \\",
            r"2, & \mathrm{otherwise}",
            r"\end{array} \right.",
        ]
    )
    codegen = function_codegen.FunctionCodegen()
    assert codegen.visit(match_stmt) == expected
@test_utils.require_at_least(10)
def test_multiple_matchvalue() -> None:
    """Each non-wildcard case becomes its own conditional row."""
    source = """
        match x:
            case 0:
                return 1
            case 1:
                return 2
            case _:
                return 3
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    expected = " ".join(
        [
            r"\left\{ \begin{array}{ll}",
            r"1, & \mathrm{if} \ x = 0 \\",
            r"2, & \mathrm{if} \ x = 1 \\",
            r"3, & \mathrm{otherwise}",
            r"\end{array} \right.",
        ]
    )
    codegen = function_codegen.FunctionCodegen()
    assert codegen.visit(match_stmt) == expected
@test_utils.require_at_least(10)
def test_single_matchvalue_no_wildcards() -> None:
    """A single case without a trailing wildcard is rejected."""
    source = """
        match x:
            case 0:
                return 1
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    codegen = function_codegen.FunctionCodegen()
    with pytest.raises(
        exceptions.LatexifySyntaxError,
        match=r"^Match statement must contain the wildcard\.$",
    ):
        codegen.visit(match_stmt)
@test_utils.require_at_least(10)
def test_multiple_matchvalue_no_wildcards() -> None:
    """Several value cases without a trailing wildcard are rejected."""
    source = """
        match x:
            case 0:
                return 1
            case 1:
                return 2
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    codegen = function_codegen.FunctionCodegen()
    with pytest.raises(
        exceptions.LatexifySyntaxError,
        match=r"^Match statement must contain the wildcard\.$",
    ):
        codegen.visit(match_stmt)
@test_utils.require_at_least(10)
def test_matchas_nonempty() -> None:
    """A capture pattern (`case ... as y`) before the wildcard is unsupported."""
    source = """
        match x:
            case [x] as y:
                return 1
            case _:
                return 2
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    codegen = function_codegen.FunctionCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match=r"^Unsupported AST: MatchAs$",
    ):
        codegen.visit(match_stmt)
@test_utils.require_at_least(10)
def test_matchvalue_no_return() -> None:
    """A case whose only statement is not a return is rejected."""
    source = """
        match x:
            case 0:
                x = 5
            case _:
                return 0
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    codegen = function_codegen.FunctionCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match=r"^Match cases must contain exactly 1 return statement\.$",
    ):
        codegen.visit(match_stmt)
@test_utils.require_at_least(10)
def test_matchvalue_mutliple_statements() -> None:
    """A case with more than one statement is rejected."""
    source = """
        match x:
            case 0:
                x = 5
                return 1
            case _:
                return 0
        """
    match_stmt = ast.parse(textwrap.dedent(source)).body[0]
    codegen = function_codegen.FunctionCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match=r"^Match cases must contain exactly 1 return statement\.$",
    ):
        codegen.visit(match_stmt)
================================================
FILE: src/latexify/codegen/function_codegen_test.py
================================================
"""Tests for latexify.codegen.function_codegen."""
from __future__ import annotations
import ast
import textwrap
import pytest
from latexify import exceptions
from latexify.codegen import function_codegen
def test_generic_visit() -> None:
    """Nodes without a dedicated visitor are reported as unsupported."""

    class UnknownNode(ast.AST):
        pass

    codegen = function_codegen.FunctionCodegen()
    with pytest.raises(
        exceptions.LatexifyNotSupportedError,
        match=r"^Unsupported AST: UnknownNode$",
    ):
        codegen.visit(UnknownNode())
def test_visit_functiondef_use_signature() -> None:
    """`use_signature` toggles the `f(x) = ` prefix and is enabled by default."""
    source = """
        def f(x):
            return x
        """
    tree = ast.parse(textwrap.dedent(source)).body[0]
    assert isinstance(tree, ast.FunctionDef)

    with_signature = r"f(x) = x"
    without_signature = "x"
    cases = [
        (function_codegen.FunctionCodegen(), with_signature),
        (function_codegen.FunctionCodegen(use_signature=False), without_signature),
        (function_codegen.FunctionCodegen(use_signature=True), with_signature),
    ]
    for codegen, expected in cases:
        assert codegen.visit(tree) == expected
def test_visit_functiondef_ignore_docstring() -> None:
    """A leading docstring does not appear in the generated LaTeX."""
    source = '''
        def f(x):
            """docstring"""
            return x
        '''
    tree = ast.parse(textwrap.dedent(source)).body[0]
    assert isinstance(tree, ast.FunctionDef)

    codegen = function_codegen.FunctionCodegen()
    assert codegen.visit(tree) == r"f(x) = x"
def test_visit_functiondef_ignore_multiple_constants() -> None:
    """Every bare constant statement in the body is skipped, not just the first."""
    source = '''
        def f(x):
            """docstring"""
            3
            True
            return x
        '''
    tree = ast.parse(textwrap.dedent(source)).body[0]
    assert isinstance(tree, ast.FunctionDef)

    codegen = function_codegen.FunctionCodegen()
    assert codegen.visit(tree) == r"f(x) = x"
================================================
FILE: src/latexify/codegen/identifier_converter.py
================================================
"""Utility to convert identifiers."""
from __future__ import annotations
from latexify.codegen import expression_rules
class IdentifierConverter:
    r"""Converts Python identifiers to appropriate LaTeX expression.

    This converter applies following rules:
    - `foo` --> `\foo`, if `use_math_symbols == True` and the given identifier
      matches a supported math symbol name.
    - `x` --> `x`, if the given identifier is exactly 1 character (except `_`)
    - `foo_bar` --> `\mathrm{foo\_bar}`, otherwise.

    If `escape_underscores == False`, underscores are not escaped. Instead, each
    underscore opens a nested subscript. Every component is converted on its own,
    and the result is wrapped by `\mathrm` (if enabled) without escaping, e.g.
    `a_b_c` --> `\mathrm{a_{b_{c}}}`.
    """

    _use_math_symbols: bool
    _use_mathrm: bool
    _escape_underscores: bool

    def __init__(
        self,
        *,
        use_math_symbols: bool,
        use_mathrm: bool = True,
        escape_underscores: bool = True,
    ) -> None:
        r"""Initializer.

        Args:
            use_math_symbols: Whether to convert identifiers with math symbol names to
                appropriate LaTeX command.
            use_mathrm: Whether to wrap the resulting expression by \mathrm, if
                applicable.
            escape_underscores: Whether to prefix any underscores in identifiers with
                '\\', disable to allow subscripts in generated latex.
        """
        self._use_math_symbols = use_math_symbols
        self._use_mathrm = use_mathrm
        self._escape_underscores = escape_underscores

    def convert(self, name: str) -> tuple[str, bool]:
        """Converts Python identifier to LaTeX expression.

        Args:
            name: Identifier name.

        Returns:
            Tuple of following values:
                - latex: Corresponding LaTeX expression.
                - is_single_character: Whether `latex` can be treated as a single
                    character or not.

        Raises:
            ValueError: `escape_underscores` is False and `name` has a leading,
                trailing, or doubled underscore (including `name == "_"`). Such a
                name would produce an empty subscript, i.e., invalid LaTeX.
        """
        if not self._escape_underscores and "_" in name:
            # Check if we are going to generate an invalid Latex string. Better to
            # raise an exception here than have the resulting Latex fail to
            # compile/display
            name_splits = name.split("_")
            if not all(name_splits):
                raise ValueError(
                    "Neither preceding/trailing underscores nor double underscores is "
                    f"allowed by the `escape_underscores` option, but got: {name}"
                )

            # Components keep the math-symbol conversion but skip \mathrm, so that
            # the wrapping below is applied exactly once around the whole result.
            sub_converter = IdentifierConverter(
                use_math_symbols=self._use_math_symbols,
                use_mathrm=False,
                escape_underscores=True,
            )
            elems = [sub_converter.convert(n)[0] for n in name_splits]

            # Wrap sub identifiers in nested braces: a_b_c -> a_{b_{c}}
            name = "_{".join(elems) + "}" * (len(elems) - 1)

        if self._use_math_symbols and name in expression_rules.MATH_SYMBOLS:
            return "\\" + name, True

        if len(name) == 1 and name != "_":
            return name, True

        escaped = name.replace("_", r"\_") if self._escape_underscores else name
        wrapped = rf"\mathrm{{{escaped}}}" if self._use_mathrm else escaped

        return wrapped, False
================================================
FILE: src/latexify/codegen/identifier_converter_test.py
================================================
"""Tests for latexify.codegen.identifier_converter."""
from __future__ import annotations
import pytest
from latexify.codegen import identifier_converter
@pytest.mark.parametrize(
    "name,use_math_symbols,use_mathrm,escape_underscores,expected",
    [
        ("a", False, True, True, ("a", True)),
        ("_", False, True, True, (r"\mathrm{\_}", False)),
        ("aa", False, True, True, (r"\mathrm{aa}", False)),
        ("a1", False, True, True, (r"\mathrm{a1}", False)),
        ("a_", False, True, True, (r"\mathrm{a\_}", False)),
        ("_a", False, True, True, (r"\mathrm{\_a}", False)),
        ("_1", False, True, True, (r"\mathrm{\_1}", False)),
        ("__", False, True, True, (r"\mathrm{\_\_}", False)),
        ("a_a", False, True, True, (r"\mathrm{a\_a}", False)),
        ("a__", False, True, True, (r"\mathrm{a\_\_}", False)),
        ("a_1", False, True, True, (r"\mathrm{a\_1}", False)),
        ("alpha", False, True, True, (r"\mathrm{alpha}", False)),
        ("alpha", True, True, True, (r"\alpha", True)),
        ("alphabet", True, True, True, (r"\mathrm{alphabet}", False)),
        ("foo", False, True, True, (r"\mathrm{foo}", False)),
        ("foo", True, True, True, (r"\mathrm{foo}", False)),
        ("foo", True, False, True, (r"foo", False)),
        ("x", True, True, True, ("x", True)),
        ("pi", True, True, True, (r"\pi", True)),
        ("aa", False, True, False, (r"\mathrm{aa}", False)),
        ("a_a", False, True, False, (r"\mathrm{a_{a}}", False)),
        ("a_1", False, True, False, (r"\mathrm{a_{1}}", False)),
        # Nested subscripts without math symbols, with and without \mathrm.
        ("a_b_c", False, True, False, (r"\mathrm{a_{b_{c}}}", False)),
        ("a_b", False, False, False, (r"a_{b}", False)),
        ("alpha", True, False, False, (r"\alpha", True)),
        ("alpha_1", True, False, False, (r"\alpha_{1}", False)),
        ("x_alpha", True, False, False, (r"x_{\alpha}", False)),
        ("x_alpha_beta", True, False, False, (r"x_{\alpha_{\beta}}", False)),
        ("alpha_beta", True, False, False, (r"\alpha_{\beta}", False)),
    ],
)
def test_identifier_converter(
    name: str,
    use_math_symbols: bool,
    use_mathrm: bool,
    escape_underscores: bool,
    expected: tuple[str, bool],
) -> None:
    """Checks the LaTeX and single-character flag produced for each identifier."""
    converter = identifier_converter.IdentifierConverter(
        use_math_symbols=use_math_symbols,
        use_mathrm=use_mathrm,
        escape_underscores=escape_underscores,
    )
    assert converter.convert(name) == expected
@pytest.mark.parametrize(
    "name,use_math_symbols,use_mathrm,escape_underscores",
    [
        ("_", False, True, False),
        ("a_", False, True, False),
        ("_a", False, True, False),
        ("_1", False, True, False),
        ("__", False, True, False),
        ("a__", False, True, False),
        ("alpha_", True, False, False),
        ("_alpha", True, False, False),
        ("x__alpha", True, False, False),
    ],
)
def test_identifier_converter_failure(
    name: str,
    use_math_symbols: bool,
    use_mathrm: bool,
    escape_underscores: bool,
) -> None:
    """Names that would yield an empty subscript are rejected.

    The message is checked so that an unrelated ValueError cannot make this pass.
    """
    converter = identifier_converter.IdentifierConverter(
        use_math_symbols=use_math_symbols,
        use_mathrm=use_mathrm,
        escape_underscores=escape_underscores,
    )
    with pytest.raises(ValueError, match=r"escape_underscores"):
        converter.convert(name)
================================================
FILE: src/latexify/codegen/latex.py
================================================
"""Definition of Latex."""
from __future__ import annotations
from collections.abc import Iterable
from typing import Union
LatexLike = Union[str, "Latex"]


class Latex:
    """LaTeX expression string for ease of writing the codegen source.

    No method mutates the underlying string, so instances behave as values: two
    Latex objects with the same string compare equal and have the same hash.
    """

    _raw: str

    def __init__(self, raw: str) -> None:
        """Initializer.

        Args:
            raw: Direct string of the underlying expression.
        """
        self._raw = raw

    def __eq__(self, other: object) -> bool:
        """Checks equality.

        Args:
            other: Other object to check equality.

        Returns:
            True if other is Latex and the underlying expression is the same as self,
            False otherwise.
        """
        return isinstance(other, Latex) and other._raw == self._raw

    def __hash__(self) -> int:
        """Returns a hash consistent with `__eq__`.

        Defining `__eq__` alone implicitly sets `__hash__` to None. That would make
        Latex objects unusable as set members or dict keys.
        """
        return hash(self._raw)

    def __repr__(self) -> str:
        """Returns a debugging representation, e.g. `Latex('x')`."""
        return f"Latex({self._raw!r})"

    def __str__(self) -> str:
        """Returns the underlying expression.

        Returns:
            The underlying expression.
        """
        return self._raw

    def __add__(self, other: object) -> Latex:
        """Concatenates two expressions.

        Args:
            other: The expression to be concatenated to the right side of self.

        Returns:
            A new expression: "{self}{other}"

        Raises:
            ValueError: `other` is neither str nor Latex.
        """
        if isinstance(other, str):
            return Latex(self._raw + other)
        if isinstance(other, Latex):
            return Latex(self._raw + other._raw)
        raise ValueError("Unsupported operation.")

    def __radd__(self, other: object) -> Latex:
        """Concatenates two expressions.

        Args:
            other: The expression to be concatenated to the left side of self.

        Returns:
            A new expression: "{other}{self}"

        Raises:
            ValueError: `other` is neither str nor Latex.
        """
        if isinstance(other, str):
            return Latex(other + self._raw)
        if isinstance(other, Latex):
            return Latex(other._raw + self._raw)
        raise ValueError("Unsupported operation.")

    @staticmethod
    def opt(src: LatexLike) -> Latex:
        """Wraps the expression by "[" and "]".

        This wrapping is used when the expression needs to be wrapped as an optional
        argument of the environment.

        Args:
            src: Original expression.

        Returns:
            A new expression with surrounding brackets.
        """
        return Latex("[" + str(src) + "]")

    @staticmethod
    def arg(src: LatexLike) -> Latex:
        """Wraps the expression by "{" and "}".

        This wrapping is used when the expression needs to be wrapped as an argument of
        other expressions.

        Args:
            src: Original expression.

        Returns:
            A new expression with surrounding brackets.
        """
        return Latex("{" + str(src) + "}")

    @staticmethod
    def paren(src: LatexLike) -> Latex:
        """Adds surrounding parentheses: "(" and ")".

        Args:
            src: Original expression.

        Returns:
            A new expression with surrounding brackets.
        """
        return Latex(r"\mathopen{}\left( " + str(src) + r" \mathclose{}\right)")

    @staticmethod
    def curly(src: LatexLike) -> Latex:
        """Adds surrounding curly brackets: "\\{" and "\\}".

        Args:
            src: Original expression.

        Returns:
            A new expression with surrounding brackets.
        """
        return Latex(r"\mathopen{}\left\{ " + str(src) + r" \mathclose{}\right\}")

    @staticmethod
    def square(src: LatexLike) -> Latex:
        """Adds surrounding square brackets: "[" and "]".

        Args:
            src: Original expression.

        Returns:
            A new expression with surrounding brackets.
        """
        return Latex(r"\mathopen{}\left[ " + str(src) + r" \mathclose{}\right]")

    @staticmethod
    def command(
        name: str,
        *,
        options: list[LatexLike] | None = None,
        args: list[LatexLike] | None = None,
    ) -> Latex:
        """Makes a Latex command expression.

        Args:
            name: Name of the command.
            options: List of optional arguments.
            args: List of arguments.

        Returns:
            A new expression.
        """
        elms: list[LatexLike] = [rf"\{name}"]
        if options is not None:
            elms += [Latex.opt(x) for x in options]
        if args is not None:
            elms += [Latex.arg(x) for x in args]

        return Latex.join("", elms)

    @staticmethod
    def environment(
        name: str,
        *,
        options: list[LatexLike] | None = None,
        args: list[LatexLike] | None = None,
        content: LatexLike | None = None,
    ) -> Latex:
        """Makes a Latex environment expression.

        Args:
            name: Name of the environment.
            options: List of optional arguments.
            args: List of arguments.
            content: Inner content of the environment.

        Returns:
            A new expression.
        """
        begin_elms: list[LatexLike] = [rf"\begin{{{name}}}"]
        if options is not None:
            begin_elms += [Latex.opt(x) for x in options]
        if args is not None:
            begin_elms += [Latex.arg(x) for x in args]

        env_elms: list[LatexLike] = [Latex.join("", begin_elms)]
        if content is not None:
            env_elms.append(content)
        env_elms.append(rf"\end{{{name}}}")

        return Latex.join(" ", env_elms)

    @staticmethod
    def join(separator: LatexLike, elements: Iterable[LatexLike]) -> Latex:
        """Joins given sequence.

        Args:
            separator: Expression of the separator between each element.
            elements: Iterable of expressions to be joined.

        Returns:
            A new Latex: "{e[0]}{s}{e[1]}{s}...{s}{e[-1]}"
            where s == separator, and e == elements.
        """
        return Latex(str(separator).join(str(x) for x in elements))
================================================
FILE: src/latexify/codegen/latex_test.py
================================================
"""Tests for latexify.codegen.latex."""
from __future__ import annotations
# Ignores [22-imports] for convenience.
from latexify.codegen.latex import Latex
def test_eq() -> None:
    """A Latex equals only another Latex with the same underlying string."""
    foo = Latex("foo")
    assert foo == Latex("foo")
    assert foo != "foo"
    assert foo != Latex("bar")
def test_str() -> None:
    """str() returns the raw expression unchanged."""
    raw = "foo"
    assert str(Latex(raw)) == raw
def test_add() -> None:
    """`+` concatenates Latex with str or Latex on either side."""
    expected = Latex("foobar")
    assert Latex("foo") + "bar" == expected
    assert "foo" + Latex("bar") == expected
    assert Latex("foo") + Latex("bar") == expected
def test_opt() -> None:
    """opt() wraps both str and Latex inputs in square brackets."""
    for src in ("foo", Latex("foo")):
        assert Latex.opt(src) == Latex("[foo]")
def test_arg() -> None:
    """arg() wraps both str and Latex inputs in curly braces."""
    for src in ("foo", Latex("foo")):
        assert Latex.arg(src) == Latex("{foo}")
def test_paren() -> None:
    """paren() surrounds both str and Latex inputs with sized parentheses."""
    expected = Latex(r"\mathopen{}\left( foo \mathclose{}\right)")
    for src in ("foo", Latex("foo")):
        assert Latex.paren(src) == expected
def test_curly() -> None:
    """curly() surrounds both str and Latex inputs with sized curly brackets."""
    expected = Latex(r"\mathopen{}\left\{ foo \mathclose{}\right\}")
    for src in ("foo", Latex("foo")):
        assert Latex.curly(src) == expected
def test_square() -> None:
    """square() surrounds both str and Latex inputs with sized square brackets."""
    expected = Latex(r"\mathopen{}\left[ foo \mathclose{}\right]")
    for src in ("foo", Latex("foo")):
        assert Latex.square(src) == expected
def test_command() -> None:
    """command() appends options as [..] and arguments as {..} after the name."""
    cases = [
        ({}, r"\a"),
        ({"options": []}, r"\a"),
        ({"options": ["b"]}, r"\a[b]"),
        ({"options": [Latex("b")]}, r"\a[b]"),
        ({"options": ["b", "c"]}, r"\a[b][c]"),
        ({"args": []}, r"\a"),
        ({"args": ["b"]}, r"\a{b}"),
        ({"args": [Latex("b")]}, r"\a{b}"),
        ({"args": ["b", "c"]}, r"\a{b}{c}"),
        ({"options": ["b"], "args": ["c"]}, r"\a[b]{c}"),
    ]
    for kwargs, expected in cases:
        assert Latex.command("a", **kwargs) == Latex(expected)
def test_environment() -> None:
    """environment() builds \\begin..\\end with optional options, args and content."""
    cases = [
        ({}, r"\begin{a} \end{a}"),
        ({"options": []}, r"\begin{a} \end{a}"),
        ({"options": ["b"]}, r"\begin{a}[b] \end{a}"),
        ({"options": [Latex("b")]}, r"\begin{a}[b] \end{a}"),
        ({"options": ["b", "c"]}, r"\begin{a}[b][c] \end{a}"),
        ({"args": []}, r"\begin{a} \end{a}"),
        ({"args": ["b"]}, r"\begin{a}{b} \end{a}"),
        ({"args": [Latex("b")]}, r"\begin{a}{b} \end{a}"),
        ({"args": ["b", "c"]}, r"\begin{a}{b}{c} \end{a}"),
        ({"content": "b"}, r"\begin{a} b \end{a}"),
        ({"content": Latex("b")}, r"\begin{a} b \end{a}"),
        ({"options": ["b"], "args": ["c"]}, r"\begin{a}[b]{c} \end{a}"),
        ({"options": ["b"], "content": "c"}, r"\begin{a}[b] c \end{a}"),
        ({"args": ["b"], "content": "c"}, r"\begin{a}{b} c \end{a}"),
        (
            {"options": ["b"], "args": ["c"], "content": "d"},
            r"\begin{a}[b]{c} d \end{a}",
        ),
    ]
    for kwargs, expected in cases:
        assert Latex.environment("a", **kwargs) == Latex(expected)
def test_join() -> None:
    """join() accepts any iterable of str/Latex elements, including generators."""
    cases = [
        ([], ""),
        (["foo"], "foo"),
        ([Latex("foo")], "foo"),
        ([Latex("foo"), "bar"], "foo:bar"),
        (["foo", Latex("bar")], "foo:bar"),
        ([Latex("foo"), Latex("bar")], "foo:bar"),
        ((), ""),
        (("foo",), "foo"),
        ((Latex("foo"),), "foo"),
        ((Latex("foo"), "bar"), "foo:bar"),
        (("foo", Latex("bar")), "foo:bar"),
        ((Latex("foo"), Latex("bar")), "foo:bar"),
        ((str(x) for x in range(3)), "0:1:2"),
        ((Latex(str(x)) for x in range(3)), "0:1:2"),
    ]
    for elements, expected in cases:
        assert Latex.join(":", elements) == Latex(expected)
================================================
FILE: src/latexify/config.py
================================================
"""Definition of the Config class."""
from __future__ import annotations
import dataclasses
from typing import Any
@dataclasses.dataclass(frozen=True)
class Config:
    """Configurations to control the behavior of latexify.

    Attributes:
        expand_functions: If set, the names of the functions to expand.
        identifiers: If set, the mapping to replace identifier names in the
            function. Keys are the original names of the identifiers,
            and corresponding values are the replacements.
            Both keys and values have to represent valid Python identifiers:
            ^[A-Za-z_][A-Za-z0-9_]*$
        prefixes: Prefixes of identifiers to trim. E.g., if "foo.bar" in prefixes, all
            identifiers with the form "foo.bar.suffix" will be replaced to "suffix"
        reduce_assignments: If True, assignment statements are used to synthesize
            the final expression.
        use_math_symbols: Whether to convert identifiers with a math symbol surface
            (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
        use_set_symbols: Whether to use set symbols or not.
        use_signature: Whether to add the function signature before the expression
            or not.
        escape_underscores: Whether to escape underscores in identifiers rather than
            rendering them as LaTeX subscripts.
    """

    expand_functions: set[str] | None
    identifiers: dict[str, str] | None
    prefixes: set[str] | None
    reduce_assignments: bool
    use_math_symbols: bool
    use_set_symbols: bool
    use_signature: bool
    escape_underscores: bool

    def merge(self, *, config: Config | None = None, **kwargs) -> Config:
        """Merges field values into a new Config.

        Each field is resolved with the precedence `kwargs` -> `config` -> `self`.
        A None value in `kwargs` counts as "not specified".

        Args:
            config: Base configuration. If given, every field not specified in
                `kwargs` is taken from `config`, and `self` is not consulted at all.
                If None, such fields are taken from `self`.
            **kwargs: Field values to override. Keys that are not fields of Config
                are ignored.

        Returns:
            A new Config object.
        """
        base = self if config is None else config

        def merge_field(name: str) -> Any:
            value = kwargs.get(name)
            return getattr(base, name) if value is None else value

        return Config(**{f.name: merge_field(f.name) for f in dataclasses.fields(self)})

    @staticmethod
    def defaults() -> Config:
        """Generates a Config with default values.

        Returns:
            A new Config with default values
        """
        return Config(
            expand_functions=None,
            identifiers=None,
            prefixes=None,
            reduce_assignments=False,
            use_math_symbols=False,
            use_set_symbols=False,
            use_signature=True,
            escape_underscores=True,
        )
================================================
FILE: src/latexify/exceptions.py
================================================
"""Exceptions used in Latexify."""
class LatexifyError(Exception):
    """Base class of all Latexify exceptions.

    Subclasses of this exception do not indicate that the user called the library
    incorrectly at the interface level. Instead, they tell the user that Latexify ran
    into a problem while processing the given function.

    These exceptions are intended to be captured by the frontend, so that a failed
    conversion does not crash the entire program.

    Errors caused by wrong inputs should raise built-in exceptions instead.
    """
class LatexifyNotSupportedError(LatexifyError):
    """Some subtree in the AST is not supported by the current implementation.

    This error is raised when the library encounters syntax it does not implement
    yet. A future version of the library may be able to handle it.
    """
class LatexifySyntaxError(LatexifyError):
    """Some subtree in the AST cannot be processed.

    This error is raised when the library encounters syntax that cannot be
    converted at all. Unlike LatexifyNotSupportedError, this limitation is
    fundamental and is not expected to be resolved in the future.
    """
================================================
FILE: src/latexify/frontend.py
================================================
"""Frontend interfaces of latexify."""
from __future__ import annotations
from collections.abc import Callable
from typing import Any, overload
from latexify import ipython_wrappers
@overload
def algorithmic(
    fn: Callable[..., Any], **kwargs: Any
) -> ipython_wrappers.LatexifiedAlgorithm: ...


@overload
def algorithmic(
    **kwargs: Any,
) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm]: ...


def algorithmic(
    fn: Callable[..., Any] | None = None, **kwargs: Any
) -> (
    ipython_wrappers.LatexifiedAlgorithm
    | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm]
):
    """Attaches LaTeX pretty-printing to the given function.

    This can be used either directly or as a decorator factory. The following two
    forms work the same way:

    - latexify.algorithmic(alg, **kwargs)
    - latexify.algorithmic(**kwargs)(alg)

    Args:
        fn: Callable to be wrapped. If omitted, a decorator is returned instead.
        **kwargs: Arguments to control behavior. See also get_latex().

    Returns:
        - If `fn` is passed, the wrapped function.
        - Otherwise, a decorator that wraps its argument with the given settings.
    """

    def wrapper(f):
        return ipython_wrappers.LatexifiedAlgorithm(f, **kwargs)

    return wrapper if fn is None else wrapper(fn)
@overload
def function(
    fn: Callable[..., Any], **kwargs: Any
) -> ipython_wrappers.LatexifiedFunction: ...


@overload
def function(
    **kwargs: Any,
) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]: ...


def function(
    fn: Callable[..., Any] | None = None, **kwargs: Any
) -> (
    ipython_wrappers.LatexifiedFunction
    | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]
):
    """Attaches LaTeX pretty-printing to the given function.

    This can be used either directly or as a decorator factory. The following two
    forms work the same way:

    - latexify.function(fn, **kwargs)
    - latexify.function(**kwargs)(fn)

    Args:
        fn: Callable to be wrapped. If omitted, a decorator is returned instead.
        **kwargs: Arguments to control behavior. See also get_latex().

    Returns:
        - If `fn` is passed, the wrapped function.
        - Otherwise, a decorator that wraps its argument with the given settings.
    """

    def wrapper(f):
        return ipython_wrappers.LatexifiedFunction(f, **kwargs)

    return wrapper if fn is None else wrapper(fn)
@overload
def expression(
    fn: Callable[..., Any], **kwargs: Any
) -> ipython_wrappers.LatexifiedFunction: ...


@overload
def expression(
    **kwargs: Any,
) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]: ...


def expression(
    fn: Callable[..., Any] | None = None, **kwargs: Any
) -> (
    ipython_wrappers.LatexifiedFunction
    | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]
):
    """Attach LaTeX pretty-printing to the given function.

    This function is a shortcut for `latexify.function` with the default parameter
    `use_signature=False`. An explicitly passed `use_signature` still takes effect.

    Args:
        fn: Callable to be wrapped. If omitted, a decorator is returned instead.
        **kwargs: Arguments to control behavior. See also get_latex().

    Returns:
        - If `fn` is passed, the wrapped function.
        - Otherwise, a decorator that wraps its argument with the given settings.
    """
    # Only fills in the default; a caller-supplied value (even None) is kept.
    kwargs.setdefault("use_signature", False)

    def wrapper(f):
        return ipython_wrappers.LatexifiedFunction(f, **kwargs)

    return wrapper if fn is None else wrapper(fn)
================================================
FILE: src/latexify/frontend_test.py
================================================
"""Tests for latexify.frontend."""
from __future__ import annotations
from latexify import frontend
def test_function() -> None:
    def f(x):
        return x

    def check(latexified, expected: str) -> None:
        assert str(latexified) == expected
        assert latexified._repr_latex_() == rf"$$ \displaystyle {expected} $$"

    body_only = "x"
    with_signature = r"f(x) = x"

    # Syntax: @function applied directly to the definition.
    check(frontend.function(f), with_signature)
    # Syntax: @function(**kwargs) applied to the definition.
    check(frontend.function(use_signature=False)(f), body_only)
    # Syntax: latexified = function(fn, **kwargs) after the definition.
    check(frontend.function(f, use_signature=False), body_only)
def test_expression() -> None:
    def f(x):
        return x

    def check(latexified, expected: str) -> None:
        assert str(latexified) == expected
        assert latexified._repr_latex_() == rf"$$ \displaystyle {expected} $$"

    body_only = "x"
    with_signature = r"f(x) = x"

    # Syntax: @expression applied directly; the signature is omitted by default.
    check(frontend.expression(f), body_only)
    # Syntax: @expression(**kwargs) applied to the definition.
    check(frontend.expression(use_signature=True)(f), with_signature)
    # Syntax: latexified = expression(fn, **kwargs) after the definition.
    check(frontend.expression(f, use_signature=True), with_signature)
================================================
FILE: src/latexify/generate_latex.py
================================================
"""Generate LaTeX code."""
from __future__ import annotations
import enum
from collections.abc import Callable
from typing import Any
from latexify import codegen
from latexify import config as cfg
from latexify import parser, transformers
class Style(enum.Enum):
    """The style of the generated LaTeX.

    `get_latex` uses this value to select the code generator.
    """

    # Rendered by codegen.AlgorithmicCodegen.
    ALGORITHMIC = "algorithmic"
    # Rendered by codegen.FunctionCodegen; the default style of get_latex.
    FUNCTION = "function"
    # Rendered by codegen.IPythonAlgorithmicCodegen; used for the IPython
    # LaTeX display of LatexifiedAlgorithm.
    IPYTHON_ALGORITHMIC = "ipython-algorithmic"
def get_latex(
    fn: Callable[..., Any],
    *,
    style: Style = Style.FUNCTION,
    config: cfg.Config | None = None,
    **kwargs,
) -> str:
    """Obtains the LaTeX description of a function from its source code.

    Args:
        fn: Reference to a function to analyze.
        style: Flavor of LaTeX to emit; defaults to Style.FUNCTION.
        config: Optional Config merged onto Config.defaults() together with
            `kwargs`.
        **kwargs: Individual Config field values supplied by the user.

    Returns:
        Generated LaTeX description.

    Raises:
        latexify.exceptions.LatexifyError: Something went wrong during conversion.
        ValueError: `style` is not one of the recognized Style members.
    """
    settings = cfg.Config.defaults().merge(config=config, **kwargs)

    tree = parser.parse_function(fn)

    # Always rewrite `x op= y` into `x = x op y` first.
    tree = transformers.AugAssignReplacer().visit(tree)

    # Optional rewrites, applied in a fixed order.
    if settings.prefixes is not None:
        tree = transformers.PrefixTrimmer(settings.prefixes).visit(tree)
    if settings.identifiers is not None:
        tree = transformers.IdentifierReplacer(settings.identifiers).visit(tree)
    if settings.reduce_assignments:
        # AssignmentReducer only accepts Assign statements before the final one,
        # so docstrings have to be removed beforehand.
        tree = transformers.DocstringRemover().visit(tree)
        tree = transformers.AssignmentReducer().visit(tree)
    if settings.expand_functions is not None:
        tree = transformers.FunctionExpander(settings.expand_functions).visit(tree)

    # Options understood by every code generator.
    shared_options = {
        "use_math_symbols": settings.use_math_symbols,
        "use_set_symbols": settings.use_set_symbols,
        "escape_underscores": settings.escape_underscores,
    }

    if style == Style.ALGORITHMIC:
        return codegen.AlgorithmicCodegen(**shared_options).visit(tree)
    if style == Style.FUNCTION:
        return codegen.FunctionCodegen(
            use_signature=settings.use_signature, **shared_options
        ).visit(tree)
    if style == Style.IPYTHON_ALGORITHMIC:
        return codegen.IPythonAlgorithmicCodegen(**shared_options).visit(tree)

    raise ValueError(f"Unrecognized style: {style}")
================================================
FILE: src/latexify/generate_latex_test.py
================================================
"""Tests for latexify.generate_latex."""
from __future__ import annotations
from latexify import generate_latex
def test_get_latex_identifiers() -> None:
    def myfn(myvar):
        return 3 * myvar

    cases = [
        ({}, r"\mathrm{myfn}(\mathrm{myvar}) = 3 \mathrm{myvar}"),
        ({"identifiers": {"myfn": "f", "myvar": "x"}}, r"f(x) = 3 x"),
    ]
    for options, expected in cases:
        assert generate_latex.get_latex(myfn, **options) == expected
def test_get_latex_prefixes() -> None:
    abc = object()

    def f(x):
        return abc.d + x.y.z.e

    untouched = r"f(x) = \mathrm{abc}.d + x.y.z.e"
    all_trimmed = r"f(x) = d + e"
    cases = [
        ({}, untouched),
        ({"prefixes": set()}, untouched),
        ({"prefixes": {"abc"}}, r"f(x) = d + x.y.z.e"),
        ({"prefixes": {"x"}}, r"f(x) = \mathrm{abc}.d + y.z.e"),
        ({"prefixes": {"x.y"}}, r"f(x) = \mathrm{abc}.d + z.e"),
        ({"prefixes": {"abc", "x.y.z"}}, all_trimmed),
        # Overlapping prefixes: "x.y.z" is trimmed entirely even though "x" is
        # also given.
        ({"prefixes": {"abc", "x", "x.y.z"}}, all_trimmed),
    ]
    for options, expected in cases:
        assert generate_latex.get_latex(f, **options) == expected
def test_get_latex_reduce_assignments() -> None:
    def f(x):
        y = 3 * x
        return y

    unreduced = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
    for options, expected in (
        ({}, unreduced),
        ({"reduce_assignments": False}, unreduced),
        ({"reduce_assignments": True}, r"f(x) = 3 x"),
    ):
        assert generate_latex.get_latex(f, **options) == expected
def test_get_latex_reduce_assignments_with_docstring() -> None:
    def f(x):
        """DocstringRemover is required."""
        y = 3 * x
        return y

    unreduced = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
    for options, expected in (
        ({}, unreduced),
        ({"reduce_assignments": False}, unreduced),
        ({"reduce_assignments": True}, r"f(x) = 3 x"),
    ):
        assert generate_latex.get_latex(f, **options) == expected
def test_get_latex_reduce_assignments_with_aug_assign() -> None:
    def f(x):
        y = 3
        y *= x
        return y

    unreduced = r"\begin{array}{l} y = 3 \\ y = y x \\ f(x) = y \end{array}"
    for options, expected in (
        ({}, unreduced),
        ({"reduce_assignments": False}, unreduced),
        ({"reduce_assignments": True}, r"f(x) = 3 x"),
    ):
        assert generate_latex.get_latex(f, **options) == expected
def test_get_latex_use_math_symbols() -> None:
    def f(alpha):
        return alpha

    plain = r"f(\mathrm{alpha}) = \mathrm{alpha}"
    for options, expected in (
        ({}, plain),
        ({"use_math_symbols": False}, plain),
        ({"use_math_symbols": True}, r"f(\alpha) = \alpha"),
    ):
        assert generate_latex.get_latex(f, **options) == expected
def test_get_latex_use_signature() -> None:
    def f(x):
        return x

    signed = r"f(x) = x"
    for options, expected in (
        ({}, signed),
        ({"use_signature": False}, "x"),
        ({"use_signature": True}, signed),
    ):
        assert generate_latex.get_latex(f, **options) == expected
def test_get_latex_use_set_symbols() -> None:
    def f(x, y):
        return x & y

    bitwise = r"f(x, y) = x \mathbin{\&} y"
    for options, expected in (
        ({}, bitwise),
        ({"use_set_symbols": False}, bitwise),
        ({"use_set_symbols": True}, r"f(x, y) = x \cap y"),
    ):
        assert generate_latex.get_latex(f, **options) == expected
================================================
FILE: src/latexify/ipython_wrappers.py
================================================
"""Wrapper objects for IPython to display output."""
from __future__ import annotations
import abc
from typing import Any, Callable, cast
from latexify import exceptions, generate_latex
class LatexifiedRepr(metaclass=abc.ABCMeta):
    """Object with LaTeX representation.

    Wraps a callable, forwards calls and the `__name__` attribute to it, and
    requires subclasses to provide the string and IPython representations.
    """

    _fn: Callable[..., Any]

    def __init__(self, fn: Callable[..., Any], **kwargs) -> None:
        # `kwargs` is accepted so all subclasses share one constructor
        # signature; the base class itself does not use it.
        self._fn = fn

    # NOTE(review): every class body stores its own `__doc__` entry (the class
    # docstring, or None), and a plain class attribute found earlier in the MRO
    # takes precedence over this property. Instances of subclasses therefore
    # report the subclass docstring, not the wrapped function's; confirm whether
    # forwarding is expected to work there.
    @property
    def __doc__(self) -> str | None:
        return self._fn.__doc__

    @__doc__.setter
    def __doc__(self, val: str | None) -> None:
        self._fn.__doc__ = val

    @property
    def __name__(self) -> str:
        return self._fn.__name__

    @__name__.setter
    def __name__(self, val: str) -> None:
        self._fn.__name__ = val

    # After Python 3.7
    # @final
    def __call__(self, *args, **kwargs) -> Any:
        """Call the wrapped function with positional and keyword arguments."""
        return self._fn(*args, **kwargs)

    @abc.abstractmethod
    def __str__(self) -> str: ...

    @abc.abstractmethod
    def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display HTML visualization."""
        ...

    @abc.abstractmethod
    def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display LaTeX visualization."""
        ...
class LatexifiedAlgorithm(LatexifiedRepr):
    """Algorithm with latex representation.

    Renders the wrapped function twice: once in the ALGORITHMIC style (used by
    `str()`) and once in the IPYTHON_ALGORITHMIC style (used by the IPython
    display hooks). Each rendering keeps its own LaTeX or error message.
    """

    _latex: str | None
    _error: str | None
    _ipython_latex: str | None
    _ipython_error: str | None

    def __init__(self, fn: Callable[..., Any], **kwargs) -> None:
        super().__init__(fn)
        self._latex, self._error = self._render(
            fn, generate_latex.Style.ALGORITHMIC, kwargs
        )
        self._ipython_latex, self._ipython_error = self._render(
            fn, generate_latex.Style.IPYTHON_ALGORITHMIC, kwargs
        )

    @staticmethod
    def _render(
        fn: Callable[..., Any],
        style: generate_latex.Style,
        options: dict[str, Any],
    ) -> tuple[str | None, str | None]:
        """Return (latex, None) on success or (None, error text) on LatexifyError."""
        try:
            latex = generate_latex.get_latex(fn, style=style, **options)
        except exceptions.LatexifyError as e:
            return None, f"{type(e).__name__}: {str(e)}"
        return latex, None

    def __str__(self) -> str:
        if self._latex is None:
            return cast(str, self._error)
        return self._latex

    def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display HTML visualization."""
        if self._ipython_error is None:
            return None
        return '' + self._ipython_error + ""

    def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display LaTeX visualization."""
        if self._ipython_latex is None:
            return self._ipython_error
        return f"$ {self._ipython_latex} $"
class LatexifiedFunction(LatexifiedRepr):
    """Function with latex representation."""

    _latex: str | None
    _error: str | None

    def __init__(self, fn: Callable[..., Any], **kwargs) -> None:
        """Render `fn` in the FUNCTION style and record any conversion error.

        Args:
            fn: Function to wrap and render.
            **kwargs: Options forwarded to generate_latex.get_latex().
        """
        super().__init__(fn, **kwargs)
        try:
            self._latex = generate_latex.get_latex(
                fn, style=generate_latex.Style.FUNCTION, **kwargs
            )
            self._error = None
        except exceptions.LatexifyError as e:
            # Keep the object usable: the error text is shown instead of LaTeX.
            self._latex = None
            self._error = f"{type(e).__name__}: {str(e)}"

    def __str__(self) -> str:
        return self._latex if self._latex is not None else cast(str, self._error)

    def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display HTML visualization."""
        return (
            '' + self._error + ""
            if self._error is not None
            else None
        )

    def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display LaTeX visualization."""
        return (
            rf"$$ \displaystyle {self._latex} $$"
            if self._latex is not None
            else self._error
        )
================================================
FILE: src/latexify/parser.py
================================================
"""Parsing utilities."""
from __future__ import annotations
import ast
import inspect
import textwrap
from collections.abc import Callable
from typing import Any
import dill # type: ignore[import]
from latexify import exceptions
def parse_function(fn: Callable[..., Any]) -> ast.Module:
    """Parses the source code of the given function into an AST.

    Args:
        fn: Target function.

    Returns:
        AST tree representing `fn`; its first statement is the FunctionDef.

    Raises:
        latexify.exceptions.LatexifySyntaxError: The parsed source does not start
            with a function definition (e.g. `fn` is a lambda).
    """
    try:
        raw_source = inspect.getsource(fn)
    except Exception:
        # inspect could not locate the source (maybe running on a console);
        # fall back to dill.
        raw_source = dill.source.getsource(fn)

    # Nested definitions come back indented; ast.parse needs column 0.
    module = ast.parse(textwrap.dedent(raw_source))

    first_statement = module.body[0] if module.body else None
    if not isinstance(first_statement, ast.FunctionDef):
        raise exceptions.LatexifySyntaxError("Not a function.")
    return module
================================================
FILE: src/latexify/parser_test.py
================================================
"""Tests for latexify.parser."""
from __future__ import annotations
import ast
import pytest
from latexify import ast_utils, exceptions, parser, test_utils
def test_parse_function_with_posonlyargs() -> None:
    def f(x):
        return x

    obtained = parser.parse_function(f)

    signature = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg="x")],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    function_def = ast_utils.create_function_def(
        name="f",
        args=signature,
        body=[ast.Return(value=ast.Name(id="x", ctx=ast.Load()))],
        decorator_list=[],
        returns=None,
        type_comment=None,
        type_params=[],
        lineno=1,
        col_offset=0,
        end_lineno=2,
        end_col_offset=0,
    )
    test_utils.assert_ast_equal(
        obtained, ast.Module(body=[function_def], type_ignores=[])
    )
def test_parse_function_with_lambda() -> None:
    """parse_function rejects lambdas with LatexifySyntaxError("Not a function.").

    The source recovered for a lambda is the statement containing it, which does
    not parse to a FunctionDef. Both an inline lambda and one bound to a name are
    checked, so the exact text of the lambda lines below matters.
    """
    with pytest.raises(exceptions.LatexifySyntaxError, match=r"^Not a function\.$"):
        parser.parse_function(lambda: ())
    with pytest.raises(exceptions.LatexifySyntaxError, match=r"^Not a function\.$"):
        x = lambda: ()  # noqa: E731
        parser.parse_function(x)
================================================
FILE: src/latexify/test_utils.py
================================================
"""Test utilities."""
from __future__ import annotations
import ast
import functools
import sys
from collections.abc import Callable
from typing import cast
def require_at_least(
    minor: int,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Run the decorated test only on Python 3.`minor` or newer.

    On older interpreters the wrapped test returns immediately without running,
    so it passes silently rather than being reported as skipped.

    Args:
        minor: Minimum minor version (inclusive) that the test case supports.

    Returns:
        A decorator function to wrap the test case function.
    """

    def decorator(fn: Callable[..., None]) -> Callable[..., None]:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if sys.version_info.minor >= minor:
                fn(*args, **kwargs)

        return wrapper

    return decorator
def require_at_most(
    minor: int,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Run the decorated test only on Python 3.`minor` or older.

    On newer interpreters the wrapped test returns immediately without running,
    so it passes silently rather than being reported as skipped.

    Args:
        minor: Maximum minor version (inclusive) that the test case supports.

    Returns:
        A decorator function to wrap the test case function.
    """

    def decorator(fn: Callable[..., None]) -> Callable[..., None]:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if sys.version_info.minor <= minor:
                fn(*args, **kwargs)

        return wrapper

    return decorator
def ast_equal(observed: ast.AST, expected: ast.AST) -> bool:
    """Checks the equality between two ASTs.

    This function checks if `observed` contains at least the same subtree with
    `expected`. If `observed` has some extra branches that `expected` does not cover,
    it is ignored. Source positions (`lineno`, `col_offset`, `end_lineno`,
    `end_col_offset`), the `kind` field and, on Python 3.12 and earlier,
    `type_params` are not compared.

    Args:
        observed: An AST to check.
        expected: The expected AST.

    Returns:
        True if observed and expected represent the same AST. A mismatch never
        returns False; it raises instead, so the traceback points at the first
        differing node.

    Raises:
        AssertionError: observed and expected differ.
        AttributeError: observed lacks a field that expected has.
    """
    ignore_keys = {"lineno", "col_offset", "end_lineno", "end_col_offset", "kind"}
    if sys.version_info.minor <= 12:
        ignore_keys.add("type_params")

    def check_leaf(key: str, vo: object, ve: object) -> None:
        # Non-AST values (identifiers, constants, None) must match exactly,
        # including their type (so 1 != 1.0 and 1 != True).
        if type(vo) is not type(ve) or vo != ve:
            raise AssertionError(f"Field {key!r} mismatch: {vo!r} != {ve!r}")

    # Explicit raises instead of `assert`, so the check still works under -O.
    if type(observed) is not type(expected):
        raise AssertionError(
            f"Node type mismatch: {type(observed).__name__} != "
            f"{type(expected).__name__}"
        )

    for k, ve in vars(expected).items():
        if k in ignore_keys:
            continue
        vo = getattr(observed, k)  # May cause AttributeError.
        if isinstance(ve, ast.AST):
            ast_equal(cast(ast.AST, vo), ve)
        elif isinstance(ve, list):
            vo = cast(list, vo)
            if len(vo) != len(ve):
                raise AssertionError(
                    f"Field {k!r} length mismatch: {len(vo)} != {len(ve)}"
                )
            for co, ce in zip(vo, ve):
                # Some list fields hold plain values, e.g. Global.names (str).
                if isinstance(ce, ast.AST):
                    ast_equal(cast(ast.AST, co), ce)
                else:
                    check_leaf(k, co, ce)
        else:
            check_leaf(k, vo, ve)

    return True
def assert_ast_equal(observed: ast.AST, expected: ast.AST) -> None:
    """Assert that two ASTs are equal in the sense of `ast_equal`.

    Args:
        observed: The AST produced by the code under test.
        expected: The AST it should match.

    Raises:
        AssertionError: observed and expected represent different ASTs (possibly
            raised from inside `ast_equal` itself).
    """
    # The dumps are only computed when the comparison result is falsy.
    assert ast_equal(observed, expected), (
        "AST does not match.\n"
        f"observed={ast.dump(observed, indent=4)}\n"
        f"expected={ast.dump(expected, indent=4)}\n"
    )
================================================
FILE: src/latexify/transformers/__init__.py
================================================
"""Package latexify.transformers."""
from latexify.transformers.assignment_reducer import AssignmentReducer
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
from latexify.transformers.docstring_remover import DocstringRemover
from latexify.transformers.function_expander import FunctionExpander
from latexify.transformers.identifier_replacer import IdentifierReplacer
from latexify.transformers.prefix_trimmer import PrefixTrimmer
# Transformer classes re-exported as the public API of latexify.transformers.
__all__ = [
    "AssignmentReducer",
    "AugAssignReplacer",
    "DocstringRemover",
    "FunctionExpander",
    "IdentifierReplacer",
    "PrefixTrimmer",
]
================================================
FILE: src/latexify/transformers/assignment_reducer.py
================================================
"""NodeTransformer to reduce assigned expressions."""
from __future__ import annotations
import ast
from typing import Any
from latexify import ast_utils, exceptions
class AssignmentReducer(ast.NodeTransformer):
    """NodeTransformer to reduce assigned expressions.

    This class rewrites a function whose body is a sequence of assignments followed
    by a final statement into a function containing only that final statement, with
    every assigned name replaced by its (already reduced) value.

    Example:
        def f(x):
            y = 2 + x
            z = 3 * y
            return 4 + z

        AssignmentReducer modifies the function above to below:

        def f(x):
            return 4 + 3 * (2 + x)
    """

    # Name -> reduced value for the function currently being visited; None when
    # not inside a FunctionDef.
    _assignments: dict[str, ast.expr] | None = None

    # TODO(odashi):
    # Currently, this function does not care much about some expressions, e.g.,
    # comprehensions or lambdas, which introduces inner scopes.
    # It may cause some mistakes in the resulting AST.

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a FunctionDef node.

        Args:
            node: The function definition to reduce.

        Returns:
            A new FunctionDef whose body is the single reduced final statement.

        Raises:
            LatexifyNotSupportedError: A statement before the last one is not an
                Assign, or an assignment target is not a plain name.
            LatexifySyntaxError: The last statement is neither Return nor If.
        """
        # Push a fresh scope for this function.
        parent_assignments = self._assignments
        self._assignments = {}
        try:
            for child in node.body[:-1]:
                if not isinstance(child, ast.Assign):
                    raise exceptions.LatexifyNotSupportedError(
                        "AssignmentReducer supports only Assign nodes, "
                        f"but got: {type(child).__name__}"
                    )
                # Reduce the right-hand side with the substitutions so far, so a
                # reassignment like `y = y + 1` uses the previous value of y.
                value = self.visit(child.value)
                for target in child.targets:
                    if not isinstance(target, ast.Name):
                        raise exceptions.LatexifyNotSupportedError(
                            "AssignmentReducer does not recognize list/tuple "
                            "decomposition."
                        )
                    self._assignments[target.id] = value

            return_original = node.body[-1]
            if not isinstance(return_original, (ast.Return, ast.If)):
                raise exceptions.LatexifySyntaxError(
                    f"Unsupported last statement: {type(return_original).__name__}"
                )
            return_transformed = self.visit(return_original)
        finally:
            # Pop the scope even when an error aborts the reduction, so a reused
            # instance does not keep substituting stale names.
            self._assignments = parent_assignments

        # `type_params` is absent on Python versions without PEP 695 syntax.
        type_params = getattr(node, "type_params", [])
        return ast_utils.create_function_def(
            name=node.name,
            args=node.args,
            body=[return_transformed],
            decorator_list=node.decorator_list,
            returns=node.returns,
            type_params=type_params,
        )

    def visit_Name(self, node: ast.Name) -> Any:
        """Visit a Name node, substituting its assigned value when one is known."""
        if self._assignments is not None:
            return self._assignments.get(node.id, node)
        return node
================================================
FILE: src/latexify/transformers/assignment_reducer_test.py
================================================
"""Tests for latexify.transformers.assignment_reducer."""
from __future__ import annotations
import ast
from latexify import ast_utils, parser, test_utils
from latexify.transformers.assignment_reducer import AssignmentReducer
def _make_ast(body: list[ast.stmt]) -> ast.Module:
    """Build the expected module for a one-parameter function `f(x)`.

    Args:
        body: Statements to use as the function body.

    Returns:
        A module containing only the generated FunctionDef.
    """
    signature = ast.arguments(
        posonlyargs=[],
        args=[ast.arg(arg="x")],
        kwonlyargs=[],
        kw_defaults=[],
        defaults=[],
    )
    function_def = ast_utils.create_function_def(
        name="f",
        args=signature,
        body=body,
        decorator_list=[],
        type_params=[],
    )
    return ast.Module(body=[function_def], type_ignores=[])
def test_unchanged() -> None:
    def f(x):
        return x

    # Nothing to reduce: the lone return statement is kept as-is.
    reduced = AssignmentReducer().visit(parser.parse_function(f))
    test_utils.assert_ast_equal(
        reduced,
        _make_ast([ast.Return(value=ast.Name(id="x", ctx=ast.Load()))]),
    )
def test_constant() -> None:
    def f(x):
        y = 0
        return y

    reduced = AssignmentReducer().visit(parser.parse_function(f))
    test_utils.assert_ast_equal(
        reduced,
        _make_ast([ast.Return(value=ast_utils.make_constant(0))]),
    )
def test_nested() -> None:
    def f(x):
        y = 2 * x
        return y

    two_x = ast.BinOp(
        left=ast_utils.make_constant(2),
        op=ast.Mult(),
        right=ast.Name(id="x", ctx=ast.Load()),
    )
    reduced = AssignmentReducer().visit(parser.parse_function(f))
    test_utils.assert_ast_equal(reduced, _make_ast([ast.Return(value=two_x)]))
def test_nested2() -> None:
    def f(x):
        y = 2 * x
        z = 3 + y
        return z

    two_x = ast.BinOp(
        left=ast_utils.make_constant(2),
        op=ast.Mult(),
        right=ast.Name(id="x", ctx=ast.Load()),
    )
    three_plus_two_x = ast.BinOp(
        left=ast_utils.make_constant(3),
        op=ast.Add(),
        right=two_x,
    )
    reduced = AssignmentReducer().visit(parser.parse_function(f))
    test_utils.assert_ast_equal(
        reduced, _make_ast([ast.Return(value=three_plus_two_x)])
    )
def test_overwrite() -> None:
    def f(x):
        y = 2 * x
        y = 3 + x
        return y

    # Only the last assignment to `y` survives.
    three_plus_x = ast.BinOp(
        left=ast_utils.make_constant(3),
        op=ast.Add(),
        right=ast.Name(id="x", ctx=ast.Load()),
    )
    reduced = AssignmentReducer().visit(parser.parse_function(f))
    test_utils.assert_ast_equal(reduced, _make_ast([ast.Return(value=three_plus_x)]))
================================================
FILE: src/latexify/transformers/aug_assign_replacer.py
================================================
"""Transformer to replace AugAssign to Assign."""
from __future__ import annotations
import ast
class AugAssignReplacer(ast.NodeTransformer):
    """NodeTransformer that lowers augmented assignments to plain ones.

    `target op= value` becomes `target = target op value`, i.e.
    AugAssign(target, op, value) => Assign([target], BinOp(target, op, value))
    """

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
        target = node.target
        # The right-hand side needs a copy of the target in Load context; every
        # other attribute of the target node is carried over unchanged.
        load_fields = dict(vars(target))
        load_fields["ctx"] = ast.Load()
        loaded_target = target.__class__(**load_fields)
        return ast.Assign(
            targets=[target],
            value=ast.BinOp(loaded_target, node.op, node.value),
        )
================================================
FILE: src/latexify/transformers/aug_assign_replacer_test.py
================================================
"""Tests for latexify.transformers.aug_assign_replacer."""
import ast
from latexify import test_utils
from latexify.transformers.aug_assign_replacer import AugAssignReplacer
def test_replace() -> None:
    def name(identifier: str, ctx: ast.expr_context) -> ast.Name:
        return ast.Name(id=identifier, ctx=ctx)

    # x += y  ==>  x = x + y
    aug_assign = ast.AugAssign(
        target=name("x", ast.Store()),
        op=ast.Add(),
        value=name("y", ast.Load()),
    )
    replaced = AugAssignReplacer().visit(aug_assign)
    test_utils.assert_ast_equal(
        replaced,
        ast.Assign(
            targets=[name("x", ast.Store())],
            value=ast.BinOp(
                left=name("x", ast.Load()),
                op=ast.Add(),
                right=name("y", ast.Load()),
            ),
        ),
    )
================================================
FILE: src/latexify/transformers/docstring_remover.py
================================================
"""Transformer to remove all docstrings."""
from __future__ import annotations
import ast
from typing import Union
from latexify import ast_utils
class DocstringRemover(ast.NodeTransformer):
    """NodeTransformer that strips every docstring.

    Any expression statement whose value is a single string constant counts as a
    docstring, so standalone string-literal statements anywhere in the tree are
    dropped as well.
    """

    def visit_Expr(self, node: ast.Expr) -> Union[ast.Expr, None]:
        # Returning None from a NodeTransformer visitor deletes the node.
        return None if ast_utils.is_str(node.value) else node
================================================
FILE: src/latexify/transformers/docstring_remover_test.py
================================================
"""Tests for latexify.transformers.docstring_remover."""
import ast
from latexify import ast_utils, parser, test_utils
from latexify.transformers.docstring_remover import DocstringRemover
def test_remove_docstrings() -> None:
    def f():
        """Test docstring."""
        x = 42
        f()  # This Expr should not be removed.
        """This string constant should also be removed."""
        return x

    tree = parser.parse_function(f).body[0]
    assert isinstance(tree, ast.FunctionDef)

    no_arguments = ast.arguments(
        posonlyargs=[],
        args=[],
        vararg=None,
        kwonlyargs=[],
        kw_defaults=[],
        kwarg=None,
        defaults=[],
    )
    # Only the non-string statements are expected to survive.
    kept_statements = [
        ast.Assign(
            targets=[ast.Name(id="x", ctx=ast.Store())],
            value=ast_utils.make_constant(42),
        ),
        ast.Expr(
            value=ast.Call(
                func=ast.Name(id="f", ctx=ast.Load()), args=[], keywords=[]
            )
        ),
        ast.Return(value=ast.Name(id="x", ctx=ast.Load())),
    ]
    expected = ast_utils.create_function_def(
        name="f",
        body=kept_statements,
        args=no_arguments,
        decorator_list=[],
        type_params=[],
    )
    test_utils.assert_ast_equal(DocstringRemover().visit(tree), expected)
================================================
FILE: src/latexify/transformers/function_expander.py
================================================
from __future__ import annotations
import ast
import functools
from collections.abc import Callable
from latexify import ast_utils, exceptions
# TODO(ZibingZhang): handle mutually recursive function expansions
class FunctionExpander(ast.NodeTransformer):
    """NodeTransformer to expand functions.

    This class replaces calls to selected functions with an expanded form.

    Example:
        def f(x, y):
            return hypot(x, y)

        FunctionExpander({"hypot"}) will modify the AST of the function above to
        below:

        def f(x, y):
            return sqrt(x**2 + y**2)
    """

    def __init__(self, functions: set[str]) -> None:
        """Initializer.

        Args:
            functions: Names of the functions to expand. A name is only expanded
                if it also has an entry in _FUNCTION_EXPANDERS; other calls are
                kept as calls.
        """
        self._functions = functions

    def visit_Call(self, node: ast.Call) -> ast.AST:
        """Visit a Call node.

        Expands the call when its function name was requested and has a
        registered expander. Otherwise, rebuilds the call with its callee,
        positional arguments and keyword values visited recursively.
        """
        func_name = ast_utils.extract_function_name_or_none(node)
        if (
            func_name is not None
            and func_name in self._functions
            and func_name in _FUNCTION_EXPANDERS
        ):
            return _FUNCTION_EXPANDERS[func_name](self, node)

        kwargs = {
            "func": self.visit(node.func),
            "args": [self.visit(x) for x in node.args],
        }
        # Guard for Call nodes constructed without a `keywords` attribute.
        if hasattr(node, "keywords"):
            kwargs["keywords"] = [
                ast.keyword(arg=x.arg, value=self.visit(x.value)) for x in node.keywords
            ]

        return ast.Call(**kwargs)
def _atan2_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `atan2(y, x)` as `atan(y / x)`.

    The rewrite is notational only: atan(y / x) drops the quadrant information
    that atan2 carries.
    """
    _check_num_args(node, 2)
    numerator, denominator = (function_expander.visit(arg) for arg in node.args)
    ratio = ast.BinOp(left=numerator, op=ast.Div(), right=denominator)
    return ast.Call(
        func=ast.Name(id="atan", ctx=ast.Load()),
        args=[ratio],
        keywords=[],
    )
def _exp_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `exp(x)` as `e ** x`."""
    _check_num_args(node, 1)
    (argument,) = node.args
    exponent = function_expander.visit(argument)
    euler = ast.Name(id="e", ctx=ast.Load())
    return ast.BinOp(left=euler, op=ast.Pow(), right=exponent)
def _exp2_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `exp2(x)` as `2 ** x`."""
    _check_num_args(node, 1)
    base = ast_utils.make_constant(2)
    exponent = function_expander.visit(node.args[0])
    return ast.BinOp(left=base, op=ast.Pow(), right=exponent)
def _expm1_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `expm1(x)` as `exp(x) - 1`."""
    _check_num_args(node, 1)
    exp_call = ast.Call(
        func=ast.Name(id="exp", ctx=ast.Load()),
        args=[node.args[0]],
        keywords=[],
    )
    # Visiting the synthesized call expands `exp` as well when it is among the
    # requested functions, and visits the argument either way.
    minuend = function_expander.visit(exp_call)
    return ast.BinOp(left=minuend, op=ast.Sub(), right=ast_utils.make_constant(1))
def _hypot_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `hypot(a, b, ...)` as `sqrt(a**2 + b**2 + ...)`.

    A call with no arguments becomes the constant 0.
    """
    if not node.args:
        return ast_utils.make_constant(0)

    def square(arg: ast.expr) -> ast.BinOp:
        return ast.BinOp(
            left=function_expander.visit(arg),
            op=ast.Pow(),
            right=ast_utils.make_constant(2),
        )

    def plus(lhs: ast.expr, rhs: ast.expr) -> ast.BinOp:
        return ast.BinOp(left=lhs, op=ast.Add(), right=rhs)

    # Left-associative sum: ((a**2 + b**2) + c**2) + ...
    sum_of_squares = functools.reduce(plus, [square(arg) for arg in node.args])
    return ast.Call(
        func=ast.Name(id="sqrt", ctx=ast.Load()),
        args=[sum_of_squares],
        keywords=[],
    )
def _log1p_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `log1p(x)` as `log(1 + x)`."""
    _check_num_args(node, 1)
    one_plus_x = ast.BinOp(
        left=ast_utils.make_constant(1),
        op=ast.Add(),
        right=function_expander.visit(node.args[0]),
    )
    return ast.Call(
        func=ast.Name(id="log", ctx=ast.Load()),
        args=[one_plus_x],
        keywords=[],
    )
def _pow_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Rewrite `pow(x, y)` as `x ** y`."""
    _check_num_args(node, 2)
    base, exponent = (function_expander.visit(arg) for arg in node.args)
    return ast.BinOp(left=base, op=ast.Pow(), right=exponent)
def _check_num_args(node: ast.Call, nargs: int) -> None:
if len(node.args) != nargs:
fn_name = ast_utils.extract_function_name_or_none(node)
raise exceptions.LatexifySyntaxError(
f"Incorrect number of arguments for {fn_name}."
f" expected: {nargs}, but got {len(node.args)}"
)
# Maps a function name to the routine that rewrites calls to it.
_FUNCTION_EXPANDERS: dict[str, Callable[[FunctionExpander, ast.Call], ast.AST]] = dict(
    atan2=_atan2_expander,
    exp=_exp_expander,
    exp2=_exp2_expander,
    expm1=_expm1_expander,
    hypot=_hypot_expander,
    log1p=_log1p_expander,
    pow=_pow_expander,
)
================================================
FILE: src/latexify/transformers/function_expander_test.py
================================================
"""Tests for latexify.transformers.function_expander."""
from __future__ import annotations
import ast
from latexify import ast_utils, test_utils
from latexify.transformers.function_expander import FunctionExpander
def test_preserve_keywords() -> None:
    def build() -> ast.Call:
        return ast.Call(
            func=ast_utils.make_name("f"),
            args=[ast_utils.make_name("x")],
            keywords=[ast.keyword(arg="y", value=ast_utils.make_constant(0))],
        )

    # Nothing is expanded, so the rebuilt call must keep its keyword argument.
    test_utils.assert_ast_equal(FunctionExpander(set()).visit(build()), build())
def test_exp() -> None:
    exp_call = ast.Call(
        func=ast_utils.make_name("exp"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expanded = FunctionExpander({"exp"}).visit(exp_call)
    # exp(x) ==> e ** x
    test_utils.assert_ast_equal(
        expanded,
        ast.BinOp(
            left=ast_utils.make_name("e"),
            op=ast.Pow(),
            right=ast_utils.make_name("x"),
        ),
    )
def test_exp_unchanged() -> None:
    def build() -> ast.Call:
        return ast.Call(
            func=ast_utils.make_name("exp"),
            args=[ast_utils.make_name("x")],
            keywords=[],
        )

    # `exp` is not requested, so the call stays untouched.
    test_utils.assert_ast_equal(FunctionExpander(set()).visit(build()), build())
def test_exp_with_attribute() -> None:
    # `math.exp(x)` is expanded just like a bare `exp(x)`.
    math_exp_call = ast.Call(
        func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expanded = FunctionExpander({"exp"}).visit(math_exp_call)
    test_utils.assert_ast_equal(
        expanded,
        ast.BinOp(
            left=ast_utils.make_name("e"),
            op=ast.Pow(),
            right=ast_utils.make_name("x"),
        ),
    )
def test_exp_unchanged_with_attribute() -> None:
    def build() -> ast.Call:
        return ast.Call(
            func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"),
            args=[ast_utils.make_name("x")],
            keywords=[],
        )

    # No expansion requested: `math.exp(x)` must survive unchanged.
    test_utils.assert_ast_equal(FunctionExpander(set()).visit(build()), build())
def test_exp_nested1() -> None:
    """exp(exp(x)) is expanded recursively to e ** (e ** x)."""
    make_name = ast_utils.make_name

    inner_call = ast.Call(func=make_name("exp"), args=[make_name("x")], keywords=[])
    source = ast.Call(func=make_name("exp"), args=[inner_call], keywords=[])

    inner_pow = ast.BinOp(left=make_name("e"), op=ast.Pow(), right=make_name("x"))
    expected = ast.BinOp(left=make_name("e"), op=ast.Pow(), right=inner_pow)

    actual = FunctionExpander({"exp"}).visit(source)
    test_utils.assert_ast_equal(actual, expected)
def test_exp_nested2() -> None:
    """exp(x) nested inside a non-expanded call f(...) is still expanded."""
    make_name = ast_utils.make_name

    exp_call = ast.Call(func=make_name("exp"), args=[make_name("x")], keywords=[])
    source = ast.Call(func=make_name("f"), args=[exp_call], keywords=[])

    e_pow_x = ast.BinOp(left=make_name("e"), op=ast.Pow(), right=make_name("x"))
    expected = ast.Call(func=make_name("f"), args=[e_pow_x], keywords=[])

    actual = FunctionExpander({"exp"}).visit(source)
    test_utils.assert_ast_equal(actual, expected)
def test_atan2() -> None:
    """atan2(y, x) is expanded to atan(y / x)."""
    make_name = ast_utils.make_name

    source = ast.Call(
        func=make_name("atan2"),
        args=[make_name("y"), make_name("x")],
        keywords=[],
    )
    quotient = ast.BinOp(left=make_name("y"), op=ast.Div(), right=make_name("x"))
    expected = ast.Call(func=make_name("atan"), args=[quotient], keywords=[])

    actual = FunctionExpander({"atan2"}).visit(source)
    test_utils.assert_ast_equal(actual, expected)
def test_exp2() -> None:
    """exp2(x) is expanded to 2 ** x."""
    expander = FunctionExpander({"exp2"})
    source = ast.Call(
        func=ast_utils.make_name("exp2"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    two_pow_x = ast.BinOp(
        left=ast_utils.make_constant(2), op=ast.Pow(), right=ast_utils.make_name("x")
    )
    test_utils.assert_ast_equal(expander.visit(source), two_pow_x)
def test_expm1() -> None:
    """expm1(x) is expanded to exp(x) - 1."""
    expander = FunctionExpander({"expm1"})
    source = ast.Call(
        func=ast_utils.make_name("expm1"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    exp_of_x = ast.Call(
        func=ast_utils.make_name("exp"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expected = ast.BinOp(
        left=exp_of_x, op=ast.Sub(), right=ast_utils.make_constant(1)
    )
    test_utils.assert_ast_equal(expander.visit(source), expected)
def test_hypot() -> None:
    """hypot(x, y) is expanded to sqrt(x ** 2 + y ** 2)."""

    def squared(name: str) -> ast.BinOp:
        # Builds a fresh `name ** 2` node.
        return ast.BinOp(
            left=ast_utils.make_name(name),
            op=ast.Pow(),
            right=ast_utils.make_constant(2),
        )

    source = ast.Call(
        func=ast_utils.make_name("hypot"),
        args=[ast_utils.make_name("x"), ast_utils.make_name("y")],
        keywords=[],
    )
    sum_of_squares = ast.BinOp(left=squared("x"), op=ast.Add(), right=squared("y"))
    expected = ast.Call(
        func=ast_utils.make_name("sqrt"), args=[sum_of_squares], keywords=[]
    )

    actual = FunctionExpander({"hypot"}).visit(source)
    test_utils.assert_ast_equal(actual, expected)
def test_hypot_no_args() -> None:
    """hypot() with no arguments is expanded to the constant 0."""
    empty_call = ast.Call(func=ast_utils.make_name("hypot"), args=[], keywords=[])
    actual = FunctionExpander({"hypot"}).visit(empty_call)
    test_utils.assert_ast_equal(actual, ast_utils.make_constant(0))
def test_log1p() -> None:
    """log1p(x) is expanded to log(1 + x)."""
    source = ast.Call(
        func=ast_utils.make_name("log1p"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    one_plus_x = ast.BinOp(
        left=ast_utils.make_constant(1),
        op=ast.Add(),
        right=ast_utils.make_name("x"),
    )
    expected = ast.Call(
        func=ast_utils.make_name("log"), args=[one_plus_x], keywords=[]
    )

    actual = FunctionExpander({"log1p"}).visit(source)
    test_utils.assert_ast_equal(actual, expected)
def test_pow() -> None:
    """pow(x, y) is expanded to x ** y."""
    source = ast.Call(
        func=ast_utils.make_name("pow"),
        args=[ast_utils.make_name(v) for v in ("x", "y")],
        keywords=[],
    )
    expected = ast.BinOp(
        left=ast_utils.make_name("x"), op=ast.Pow(), right=ast_utils.make_name("y")
    )

    actual = FunctionExpander({"pow"}).visit(source)
    test_utils.assert_ast_equal(actual, expected)
================================================
FILE: src/latexify/transformers/identifier_replacer.py
================================================
"""Transformer to replace user symbols."""
from __future__ import annotations
import ast
import keyword
from typing import cast
from latexify import ast_utils
class IdentifierReplacer(ast.NodeTransformer):
    """NodeTransformer to replace identifier names.

    This class defines a rule to replace identifiers in AST with specified names.

    Example:
        def foo(bar):
            return baz

    IdentifierReplacer({"foo": "x", "bar": "y", "baz": "z"}) will modify the AST of
    the function above to below:

        def x(y):
            return z
    """

    def __init__(self, mapping: dict[str, str]):
        """Initializer.

        Args:
            mapping: User defined mapping of names. Keys are the original names of the
                identifiers, and corresponding values are the replacements.
                Both keys and values have to be valid Python identifiers (as judged
                by `str.isidentifier`) and must not be reserved keywords.

        Raises:
            ValueError: If any key or value of `mapping` is not a valid identifier
                name.
        """
        self._mapping = mapping

        for k, v in self._mapping.items():
            if not str.isidentifier(k) or keyword.iskeyword(k):
                raise ValueError(f"'{k}' is not an identifier name.")
            if not str.isidentifier(v) or keyword.iskeyword(v):
                raise ValueError(f"'{v}' is not an identifier name.")

    def _replace_arg(self, a: ast.arg) -> ast.arg:
        """Returns a new arg with its name replaced and other fields preserved."""
        # Keep the (already visited) annotation and type comment; building a bare
        # `ast.arg(arg=...)` would silently drop them.
        return ast.arg(
            arg=self._mapping.get(a.arg, a.arg),
            annotation=getattr(a, "annotation", None),
            type_comment=getattr(a, "type_comment", None),
        )

    def _replace_args(self, args: list[ast.arg]) -> list[ast.arg]:
        """Helper function to replace arg names."""
        return [self._replace_arg(a) for a in args]

    def _replace_optional_arg(self, a: ast.arg | None) -> ast.arg | None:
        """Helper to replace the name of an optional arg (`*args` / `**kwargs`)."""
        return self._replace_arg(a) if a is not None else None

    def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef:
        """Visit a FunctionDef node.

        Names inside the function (body, defaults, decorators, annotations) are
        replaced by visiting the children first; then the function name and every
        parameter name, including `*args`/`**kwargs`, are renamed so the signature
        stays consistent with the renamed body.
        """
        visited = cast(ast.FunctionDef, super().generic_visit(node))

        args = ast.arguments(
            posonlyargs=self._replace_args(visited.args.posonlyargs),
            args=self._replace_args(visited.args.args),
            vararg=self._replace_optional_arg(visited.args.vararg),
            kwonlyargs=self._replace_args(visited.args.kwonlyargs),
            kw_defaults=visited.args.kw_defaults,
            kwarg=self._replace_optional_arg(visited.args.kwarg),
            defaults=visited.args.defaults,
        )

        # `type_params` only exists on Python 3.12+ ASTs.
        type_params = getattr(visited, "type_params", [])

        return ast_utils.create_function_def(
            name=self._mapping.get(visited.name, visited.name),
            args=args,
            body=visited.body,
            decorator_list=visited.decorator_list,
            returns=visited.returns,
            type_params=type_params,
        )

    def visit_Name(self, node: ast.Name) -> ast.Name:
        """Visit a Name node; returns a new Name with the mapped id, if any."""
        return ast.Name(
            id=self._mapping.get(node.id, node.id),
            ctx=node.ctx,
        )
================================================
FILE: src/latexify/transformers/identifier_replacer_test.py
================================================
"""Tests for latexify.transformer.identifier_replacer."""
from __future__ import annotations
import ast
import pytest
from latexify import ast_utils, test_utils
from latexify.transformers.identifier_replacer import IdentifierReplacer
def test_invalid_mapping() -> None:
    """Non-identifier keys/values and keywords are rejected with ValueError."""
    cases = [
        ({"123": "foo"}, r"'123' is not an identifier name."),
        ({"foo": "456"}, r"'456' is not an identifier name."),
        ({"foo": "def"}, r"'def' is not an identifier name."),
    ]
    for mapping, pattern in cases:
        with pytest.raises(ValueError, match=pattern):
            IdentifierReplacer(mapping)
def test_name_replaced() -> None:
    """A Name whose id is a mapping key is renamed to the mapped value."""
    replacer = IdentifierReplacer({"foo": "bar"})
    actual = replacer.visit(ast.Name(id="foo", ctx=ast.Load()))
    test_utils.assert_ast_equal(actual, ast.Name(id="bar", ctx=ast.Load()))
def test_name_not_replaced() -> None:
    """Only exact matches are renamed; truncations/extensions of a key are not."""
    source = ast.Name(id="foo", ctx=ast.Load())
    expected = ast.Name(id="foo", ctx=ast.Load())
    for near_miss in ("fo", "fooo"):
        actual = IdentifierReplacer({near_miss: "bar"}).visit(source)
        test_utils.assert_ast_equal(actual, expected)
def test_functiondef_with_posonlyargs() -> None:
    # Subtree of:
    #     @d
    #     def f(x=a, /, y=b, *, z=c):
    #         pass
    def build(rename) -> ast.FunctionDef:
        # Builds the FunctionDef above with every identifier passed through
        # `rename`; a fresh tree is produced on each call.
        def load(name: str) -> ast.Name:
            return ast.Name(id=rename(name), ctx=ast.Load())

        return ast_utils.create_function_def(
            name=rename("f"),
            args=ast.arguments(
                posonlyargs=[ast.arg(arg=rename("x"))],
                args=[ast.arg(arg=rename("y"))],
                kwonlyargs=[ast.arg(arg=rename("z"))],
                kw_defaults=[load("c")],
                defaults=[load("a"), load("b")],
            ),
            body=[ast.Pass()],
            decorator_list=[load("d")],
            returns=None,
            type_comment=None,
            type_params=[],
        )

    source = build(str)
    expected = build(str.upper)
    mapping = {ch: ch.upper() for ch in "abcdfxyz"}
    actual = IdentifierReplacer(mapping).visit(source)
    test_utils.assert_ast_equal(actual, expected)
def test_expr() -> None:
    # Subtree of:
    #     (x + y) * z
    def build(rename) -> ast.BinOp:
        # Builds `(x + y) * z` with every identifier passed through `rename`.
        def load(name: str) -> ast.Name:
            return ast.Name(id=rename(name), ctx=ast.Load())

        return ast.BinOp(
            left=ast.BinOp(left=load("x"), op=ast.Add(), right=load("y")),
            op=ast.Mult(),
            right=load("z"),
        )

    source = build(str)
    expected = build(str.upper)
    mapping = {ch: ch.upper() for ch in "xyz"}
    actual = IdentifierReplacer(mapping).visit(source)
    test_utils.assert_ast_equal(actual, expected)
================================================
FILE: src/latexify/transformers/prefix_trimmer.py
================================================
"""NodeTransformer to trim unnecessary prefixes."""
from __future__ import annotations
import ast
import re
from latexify import ast_utils
_PREFIX_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$")


class PrefixTrimmer(ast.NodeTransformer):
    """NodeTransformer to trim unnecessary prefixes.

    This class investigates all Attribute subtrees, and replaces them if the prefix of
    the attribute matches the given set of prefixes.
    Prefix is searched in the manner of leftmost longest matching.

    Example:
        def f(x):
            return math.sqrt(x)

    PrefixTrimmer({"math"}) will modify the AST of the function above to below:

        def f(x):
            return sqrt(x)
    """

    _prefixes: list[tuple[str, ...]]

    def __init__(self, prefixes: set[str]) -> None:
        """Initializer.

        Args:
            prefixes: Set of prefixes to be trimmed. Nested prefix is allowed too.
                Each value must follow one of the following formats:
                - A Python identifier, e.g., "math"
                - Python identifiers joined by periods, e.g., "numpy.random"

        Raises:
            ValueError: If any element of `prefixes` does not follow the format above.
        """
        for p in prefixes:
            # fullmatch (not match): with `match`, the `$` anchor would also accept
            # a trailing newline such as "math\n", which could never match later.
            if not _PREFIX_PATTERN.fullmatch(p):
                raise ValueError(f"Invalid prefix: {p}")

        self._prefixes = [tuple(p.split(".")) for p in prefixes]

    def _get_prefix(self, node: ast.expr) -> tuple[str, ...] | None:
        """Helper to obtain nested prefix.

        Args:
            node: Node to investigate.

        Returns:
            The dotted name as a tuple of components (e.g. ("numpy", "random") for
            `numpy.random`), or None if the node is not a chain of Attribute nodes
            rooted at a Name.
        """
        if isinstance(node, ast.Name):
            return (node.id,)

        if isinstance(node, ast.Attribute):
            parent = self._get_prefix(node.value)
            return parent + (node.attr,) if parent is not None else None

        return None

    def _make_attribute(self, prefix: tuple[str, ...], name: str) -> ast.expr:
        """Helper to generate a new Attribute or Name node.

        Args:
            prefix: List of prefixes.
            name: Attribute name.

        Returns:
            Name node if prefix == (), (possibly nested) Attribute node otherwise.
        """
        if not prefix:
            return ast_utils.make_name(name)

        parent = self._make_attribute(prefix[:-1], prefix[-1])
        return ast_utils.make_attribute(parent, name)

    def visit_Attribute(self, node: ast.Attribute) -> ast.expr:
        """Visit an Attribute node."""
        prefix = self._get_prefix(node.value)
        if prefix is None:
            # The base is not a plain dotted name (e.g. a call or subscript), so this
            # attribute itself can't be trimmed, but its subtree may still contain
            # trimmable attributes, e.g. `math.pi` in `f(math.pi).real`.
            return self.generic_visit(node)

        # Performs leftmost longest match.
        # NOTE(odashi):
        # This implementation is very naive, but would work efficiently as long as the
        # number of patterns is small.
        matched_length = 0
        for p in self._prefixes:
            # A pattern longer than `prefix` yields a shorter slice and never matches.
            if len(p) > matched_length and prefix[: len(p)] == p:
                matched_length = len(p)

        return self._make_attribute(prefix[matched_length:], node.attr)
================================================
FILE: src/latexify/transformers/prefix_trimmer_test.py
================================================
"""Tests for latexify.transformers.prefix_trimmer."""
from __future__ import annotations
import ast
import pytest
from latexify import ast_utils, test_utils
from latexify.transformers import prefix_trimmer
# Short aliases used throughout the tests in this module.
make_name, make_attr = ast_utils.make_name, ast_utils.make_attribute
PrefixTrimmer = prefix_trimmer.PrefixTrimmer
@pytest.mark.parametrize(
    "prefix",
    # Each case must be its own list item: a missing comma between two adjacent
    # string literals silently concatenates them into a single (wrong) case.
    [".x", "x.", "1", "1x", "x.1", "x.1x", "x.x.1", "x.x.1x", "x..x", "x.x..x"],
)
def test_invalid_prefix(prefix: str) -> None:
    """Malformed dotted prefixes are rejected at construction time."""
    with pytest.raises(ValueError, match=rf"^Invalid prefix: {prefix}$"):
        PrefixTrimmer({prefix})
@pytest.mark.parametrize(
    ("prefixes", "expected"),
    [
        (set(), make_name("foo")),
        ({"foo"}, make_name("foo")),
        ({"bar"}, make_name("foo")),
        ({"foo.bar"}, make_name("foo")),
        ({"foo", "bar"}, make_name("foo")),
        ({"foo", "foo.bar"}, make_name("foo")),
    ],
)
def test_name(prefixes: set[str], expected: ast.expr) -> None:
    """A bare Name has no attribute prefix, so it is never trimmed."""
    trimmer = PrefixTrimmer(prefixes)
    test_utils.assert_ast_equal(trimmer.visit(make_name("foo")), expected)
@pytest.mark.parametrize(
    ("prefixes", "expected"),
    [
        (set(), make_attr(make_name("foo"), "bar")),
        ({"fo"}, make_attr(make_name("foo"), "bar")),
        ({"foo"}, make_name("bar")),
        ({"bar"}, make_attr(make_name("foo"), "bar")),
        ({"baz"}, make_attr(make_name("foo"), "bar")),
        ({"foo.bar"}, make_attr(make_name("foo"), "bar")),
        ({"foo", "bar"}, make_name("bar")),
        ({"foo", "foo.bar"}, make_name("bar")),
    ],
)
def test_attr_1(prefixes: set[str], expected: ast.expr) -> None:
    """`foo.bar` loses its `foo.` part only when "foo" is a registered prefix."""
    trimmer = PrefixTrimmer(prefixes)
    foo_bar = make_attr(make_name("foo"), "bar")
    test_utils.assert_ast_equal(trimmer.visit(foo_bar), expected)
@pytest.mark.parametrize(
    ("prefixes", "expected"),
    [
        (set(), make_attr(make_attr(make_name("foo"), "bar"), "baz")),
        ({"fo"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")),
        ({"foo"}, make_attr(make_name("bar"), "baz")),
        ({"bar"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")),
        ({"baz"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")),
        ({"foo.ba"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")),
        ({"foo.bar"}, make_name("baz")),
        ({"foo.bar.baz"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")),
        ({"foo", "bar"}, make_attr(make_name("bar"), "baz")),
        ({"foo", "foo.bar"}, make_name("baz")),
    ],
)
def test_attr_2(prefixes: set[str], expected: ast.expr) -> None:
    """`foo.bar.baz` is trimmed by the longest registered leading prefix."""
    trimmer = PrefixTrimmer(prefixes)
    foo_bar_baz = make_attr(make_attr(make_name("foo"), "bar"), "baz")
    test_utils.assert_ast_equal(trimmer.visit(foo_bar_baz), expected)