Repository: google/latexify_py Branch: main Commit: 54dc86971963 Files: 66 Total size: 275.1 KB Directory structure: gitextract__eefs9oe/ ├── .github/ │ ├── ISSUE_TEMPLATE/ │ │ ├── bug_report.md │ │ └── feature_request.md │ ├── pull_request_template.md │ └── workflows/ │ ├── ci.yml │ └── release.yml ├── .gitignore ├── CODEOWNERS ├── CONTRIBUTING.md ├── LICENSE ├── README.md ├── checks.sh ├── docs/ │ ├── getting_started.md │ ├── index.md │ └── parameters.md ├── examples/ │ └── latexify_examples.ipynb ├── pyproject.toml └── src/ ├── integration_tests/ │ ├── __init__.py │ ├── algorithmic_style_test.py │ ├── function_expansion_test.py │ ├── integration_utils.py │ └── regression_test.py └── latexify/ ├── __init__.py ├── _version.py ├── analyzers.py ├── analyzers_test.py ├── ast_utils.py ├── ast_utils_test.py ├── codegen/ │ ├── __init__.py │ ├── algorithmic_codegen.py │ ├── algorithmic_codegen_test.py │ ├── codegen_utils.py │ ├── codegen_utils_test.py │ ├── expression_codegen.py │ ├── expression_codegen_test.py │ ├── expression_rules.py │ ├── expression_rules_test.py │ ├── function_codegen.py │ ├── function_codegen_match_test.py │ ├── function_codegen_test.py │ ├── identifier_converter.py │ ├── identifier_converter_test.py │ ├── latex.py │ └── latex_test.py ├── config.py ├── exceptions.py ├── frontend.py ├── frontend_test.py ├── generate_latex.py ├── generate_latex_test.py ├── ipython_wrappers.py ├── parser.py ├── parser_test.py ├── test_utils.py └── transformers/ ├── __init__.py ├── assignment_reducer.py ├── assignment_reducer_test.py ├── aug_assign_replacer.py ├── aug_assign_replacer_test.py ├── docstring_remover.py ├── docstring_remover_test.py ├── function_expander.py ├── function_expander_test.py ├── identifier_replacer.py ├── identifier_replacer_test.py ├── prefix_trimmer.py └── prefix_trimmer_test.py ================================================ FILE CONTENTS ================================================ ================================================ FILE: 
.github/ISSUE_TEMPLATE/bug_report.md ================================================ --- name: Bug report about: Create a report to help us improve title: '' labels: triage assignees: odashi --- ## Environment If you used latexify in a browser, fill in the following items. * Browser: * Frontend: If you used latexify in your own environment, fill in at least the following items. Feel free to add other items if you think they are useful. * OS: * Python: * Package manager: * Latexify version: ## Description Describe the details of the issue. Feel free to insert screenshots if they are useful. ## Reproduction Describe how other people can reproduce the issue. ## Expected behavior Describe how latexify should behave in the case above. ================================================ FILE: .github/ISSUE_TEMPLATE/feature_request.md ================================================ --- name: Feature request about: Suggest an idea for this project title: '' labels: feature assignees: odashi --- ## Description Is your feature request related to a problem? Please describe it. A clear and concise description helps the discussion proceed efficiently. ## Ideas of the solution If you have an idea about the solution you'd like, describe it in detail. ## Alternative ideas If you have already considered other ideas, describe them as well. These ideas may also help us make reasonable decisions. ## Additional context Add any other context or screenshots about the feature request here. 
================================================ FILE: .github/pull_request_template.md ================================================ # Overview # Details # References # Blocked by ================================================ FILE: .github/workflows/ci.yml ================================================ name: Continuous integration on: push: branches: - main pull_request: branches: ["**"] jobs: unit-tests: runs-on: ubuntu-latest strategy: matrix: python-version: ["3.9", "3.10", "3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v4 with: python-version: ${{ matrix.python-version }} - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install -e ".[dev]" - name: Test run: python -m pytest src black: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v4 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install black - name: Check run: python -m black -v --check src flake8: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v4 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install pyproject-flake8 - name: Check run: pflake8 -v src isort: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v4 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install isort - name: Check run: python -m isort --check src mypy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v3 - name: Set up Python uses: actions/setup-python@v4 with: python-version: "3.11" - name: Install dependencies run: | python -m pip install --upgrade pip python -m pip install '.[mypy]' - name: Check run: python -m mypy src 
================================================ FILE: .github/workflows/release.yml ================================================ name: Release workflow on: push: tags: - "v[0123456789].*" jobs: release: runs-on: ubuntu-latest steps: - name: checkout uses: actions/checkout@v3 - name: setup python uses: actions/setup-python@v2 with: python-version: "3.10" - name: build run: | python -m pip install --upgrade build hatch python -m hatch version "${GITHUB_REF_NAME}" python -m build - name: publish uses: pypa/gh-action-pypi-publish@release/v1 with: password: ${{ secrets.PYPI_API_TOKEN }} ================================================ FILE: .gitignore ================================================ # Temporary files .swp temp tmp # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ share/python-wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds the exe, so as to inject date/other infos into it. 
*.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .nox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover *.py,cover .hypothesis/ .pytest_cache/ cover/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 db.sqlite3-journal # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder .pybuilder/ target/ # Jupyter Notebook .ipynb_checkpoints # IPython profile_default/ ipython_config.py # pyenv # For a library or package, you might want to ignore these files since the code is # intended to run in multiple environments; otherwise, check them in: .python-version # pipenv # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. # However, in case of collaboration, if having platform-specific dependencies or dependencies # having no cross-platform support, pipenv may install dependencies that don't work, or not # install all needed dependencies. #Pipfile.lock # PEP 582; used by e.g. github.com/David-OConnor/pyflow __pypackages__/ # Celery stuff celerybeat-schedule celerybeat.pid # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # PyCharm project settings .idea/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ .dmypy.json dmypy.json # Pyre type checker .pyre/ # pytype static type analyzer .pytype/ # Cython debug symbols cython_debug/ ================================================ FILE: CODEOWNERS ================================================ * @odashi ================================================ FILE: CONTRIBUTING.md ================================================ # How to Contribute We'd love to accept your patches and contributions to this project. There are just a few small guidelines you need to follow. 
## Contributor License Agreement Contributions to this project must be accompanied by a Contributor License Agreement (CLA). You (or your employer) retain the copyright to your contribution; this simply gives us permission to use and redistribute your contributions as part of the project. Head over to <https://cla.developers.google.com/> to see your current agreements on file or to sign a new one. You generally only need to submit a CLA once, so if you've already submitted one (even if it was for a different project), you probably don't need to do it again. ## Code reviews All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests. ## Community Guidelines This project follows [Google's Open Source Community Guidelines](https://opensource.google/conduct/). ## Coding style This project follows [TensorFlow's style](https://www.tensorflow.org/community/contribute/code_style). ================================================ FILE: LICENSE ================================================ Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. 
For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. "Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. 
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. 
If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. 
You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. 
Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ================================================ FILE: README.md ================================================ # latexify [![Python](https://img.shields.io/pypi/pyversions/latexify-py.svg)](https://pypi.org/project/latexify-py/) [![PyPI Latest Release](https://img.shields.io/pypi/v/latexify-py.svg)](https://pypi.org/project/latexify-py/) [![License](https://img.shields.io/pypi/l/latexify-py.svg)](https://github.com/google/latexify_py/blob/main/LICENSE) [![Downloads](https://pepy.tech/badge/latexify-py/month)](https://pepy.tech/project/latexify-py) [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black) [![Imports: isort](https://img.shields.io/badge/%20imports-isort-%231674b1?style=flat&labelColor=ef8336)](https://pycqa.github.io/isort/) `latexify` is a Python package to compile a fragment of Python source code to a corresponding $\LaTeX$ expression: ![Example of latexify usage](https://raw.githubusercontent.com/google/latexify_py/main/example.jpg) `latexify` provides the following functionality: * Libraries to compile Python source code or AST to $\LaTeX$. * IPython classes to pretty-print compiled functions. ## FAQs 1. *Which Python versions are supported?* The syntax of **Python 3.9 to 3.13** is officially supported, or support is planned. 2. *Which technique is used?* `latexify` is implemented as a rule-based system on top of the standard `ast` module. 3. 
*Are "AI" techniques adopted?* `latexify` is based on traditional parsing techniques. If "AI" means machine-learning techniques, the answer is no. ## Getting started See the [example notebook](https://github.com/google/latexify_py/blob/main/examples/latexify_examples.ipynb), which provides several use cases of this library. You can also try the above notebook on [Google Colaboratory](https://colab.research.google.com/github/google/latexify_py/blob/main/examples/latexify_examples.ipynb). See also the official [documentation](https://github.com/google/latexify_py/blob/main/docs/index.md) for more details. ## How to Contribute To contribute to this project, please refer to [CONTRIBUTING.md](https://github.com/google/latexify_py/blob/main/CONTRIBUTING.md). ## Disclaimer This software is currently hosted on <https://github.com/google>, but is not officially supported by Google. If you have any issues and/or questions about this software, please visit the [issue tracker](https://github.com/google/latexify_py/issues) or contact the [main maintainer](https://github.com/odashi). ## License This software adopts the [Apache License 2.0](https://github.com/google/latexify_py/blob/main/LICENSE). ================================================ FILE: checks.sh ================================================ #!/bin/bash set -eoux pipefail python -m pytest src -vv python -m black --check src python -m pflake8 src python -m isort --check src python -m mypy src ================================================ FILE: docs/getting_started.md ================================================ # Getting started This document describes how to use `latexify` with your Python code. ## Installation `latexify` depends only on Python libraries at this point. You can simply install `latexify` via `pip`: ```shell $ pip install latexify-py ``` Note that you have to install `latexify-py` rather than `latexify`. 
## Using `latexify` in Jupyter The `latexify.function` decorator wraps your functions to pretty-print them as the corresponding LaTeX formulas. Jupyter recognizes this wrapper and tries to print LaTeX instead of the original function. The following snippet: ```python @latexify.function def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) solve ``` will print the following formula to the output: $$ \mathrm{solve}(a, b, c) = \frac{-b + \sqrt{b^2 - 4ac}}{2a} $$ Invoking a wrapped function works transparently, just like the original function. ```python solve(1, 2, 1) ``` ``` -1.0 ``` Applying `str` to the wrapped function returns the underlying LaTeX source. ```python print(solve) ``` ``` \mathrm{solve}(a, b, c) = \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a} ``` `latexify.expression` works similarly to `latexify.function`, but it prints the function without its signature: ```python @latexify.expression def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) solve ``` $$ \frac{-b + \sqrt{b^2 - 4ac}}{2a} $$ ## Obtaining the LaTeX expression directly You can also use `latexify.get_latex`, which takes a function and directly returns the LaTeX expression corresponding to the given function. The same parameters as `latexify.function` can be applied to `latexify.get_latex` as well. 
```python def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) latexify.get_latex(solve) ``` ``` '\\mathrm{solve}(a, b, c) = \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a}' ``` ================================================ FILE: docs/index.md ================================================ # `latexify` documentation ## Index * [Getting started](getting_started.md) * [Parameters](parameters.md) ## External resources * [Examples on Google Colaboratory](https://colab.research.google.com/drive/1MuiawKpVIZ12MWwyYuzZHmbKThdM5wNJ?usp=sharing) ================================================ FILE: docs/parameters.md ================================================ # `latexify` parameters This document describes the parameters that control the behavior of `latexify`. ## `identifiers: dict[str, str]` Key-value pairs of identifiers to replace: each key is an identifier in the source code, and its value is the replacement used in the output. ```python identifiers = { "my_function": "f", "my_inner_function": "g", "my_argument": "x", } @latexify.function(identifiers=identifiers) def my_function(my_argument): return my_inner_function(my_argument) my_function ``` $$f(x) = g\left(x\right)$$ ## `reduce_assignments: bool` Whether to substitute all variables assigned before the `return` statement into the returned expression. The current version of `latexify` recognizes only assignment statements. Analyzing functions with other control flows may raise errors. ```python @latexify.function(reduce_assignments=True) def f(a, b, c): discriminant = b**2 - 4 * a * c numerator = -b + math.sqrt(discriminant) denominator = 2 * a return numerator / denominator f ``` $$f(a, b, c) = \frac{-b + \sqrt{b^{2} - 4 a c}}{2 a}$$ ## `use_math_symbols: bool` Whether to automatically convert variables whose names are math symbols (e.g., `alpha`, `Omega`) into the corresponding LaTeX symbols or not. 
```python @latexify.function(use_math_symbols=True) def greek(alpha, beta, gamma, Omega): return alpha * beta + math.gamma(gamma) + Omega greek ``` $$\mathrm{greek}({\alpha}, {\beta}, {\gamma}, {\Omega}) = {\alpha} {\beta} + \Gamma\left({{\gamma}}\right) + {\Omega}$$ ## `use_set_symbols: bool` Whether to render set operations with set-theoretic symbols (e.g., $\cap$, $\cup$, $\subset$) or not. ```python @latexify.function(use_set_symbols=True) def f(x, y): return x & y, x | y, x - y, x ^ y, x < y, x <= y, x > y, x >= y f ``` $$f(x, y) = \left( x \cap y\space,\space x \cup y\space,\space x \setminus y\space,\space x \mathbin{\triangle} y\space,\space {x \subset y}\space,\space {x \subseteq y}\space,\space {x \supset y}\space,\space {x \supseteq y}\right)$$ ## `use_signature: bool` Whether to output the function signature or not. The default value of this flag depends on the frontend function. `True` is used in `latexify.function`, while `False` is used in `latexify.expression`. ```python @latexify.function(use_signature=False) def f(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) f ``` $$\frac{-b + \sqrt{b^{2} - 4 a c}}{2 a}$$ ================================================ FILE: examples/latexify_examples.ipynb ================================================ { "cells": [ { "cell_type": "markdown", "metadata": { "id": "W5mNJI3Bnl6n" }, "source": [ "# `latexify` examples\n", "\n", "This notebook provides several examples of using `latexify`.\n", "\n", "See also the\n", "[official documentation](https://github.com/google/latexify_py/blob/main/docs/index.md)\n", "for more details.\n", "\n", "If you have any questions, please ask them in the\n", "[issue tracker](https://github.com/google/latexify_py/issues)." 
] }, { "cell_type": "markdown", "metadata": { "id": "fWCVgcRHoLd8" }, "source": [ "## Install `latexify`" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "4IPGyu2dFH6T", "outputId": "471cab8d-3069-4a27-f3ff-67ba177ec58d" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting latexify-py\n", " Downloading latexify_py-0.4.2-py3-none-any.whl (38 kB)\n", "Collecting dill>=0.3.2 (from latexify-py)\n", " Downloading dill-0.3.7-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.3/115.3 kB\u001b[0m \u001b[31m5.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hInstalling collected packages: dill, latexify-py\n", "Successfully installed dill-0.3.7 latexify-py-0.4.2\n" ] } ], "source": [ "# Restart the runtime before running the examples below.\n", "%pip install latexify-py\n" ] }, { "cell_type": "markdown", "metadata": { "id": "-Mzq4_dNoSmc" }, "source": [ "## Import `latexify` into your code" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "hViDMhyMFNCO", "outputId": "b46edb25-5952-4cff-da1e-d65e7e3caad0" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'0.4.2'" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import math # Optional\n", "import numpy as np # Optional\n", "import latexify\n", "\n", "latexify.__version__\n" ] }, { "cell_type": "markdown", "metadata": { "id": "4QJ6I2s7odX1" }, "source": [ "## Examples" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "NvbEYSwXFaeE", "outputId": "5d0ca2a4-a285-4053-9cc4-3776746443be" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "-1.0\n", 
"\\mathrm{solve}(a, b, c) = \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a}\n" ] } ], "source": [ "@latexify.function\n", "def solve(a, b, c):\n", " return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)\n", "\n", "print(solve(1, 4, 3)) # Invoking the function works as expected.\n", "print(solve) # Printing the function shows the underlying LaTeX source.\n", "solve # Displays the expression.\n", "\n", "# Writes the underlying LaTeX source into a file.\n", "with open(\"compiled.tex\", \"w\") as fp:\n", " print(solve, file=fp)\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 56 }, "id": "wS7BhtPgjSak", "outputId": "76a8547c-e6b5-458d-aeb2-f9df2f35f7c7" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a} $$" ], "text/plain": [ "" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# latexify.expression works similarly, but does not output the signature.\n", "@latexify.expression\n", "def solve(a, b, c):\n", " return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)\n", "\n", "solve\n" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 35 }, "id": "G73dnoqqjg4A", "outputId": "b9f53cf8-4a34-452c-8d9b-946ddd0998df" }, "outputs": [ { "data": { "application/vnd.google.colaboratory.intrinsic+json": { "type": "string" }, "text/plain": [ "'\\\\mathrm{solve}(a, b, c) = \\\\frac{-b + \\\\sqrt{ b^{2} - 4 a c }}{2 a}'" ] }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# latexify.get_latex obtains the underlying LaTeX expression directly.\n", "def solve(a, b, c):\n", " return (-b + math.sqrt(b**2 - 4*a*c)) / (2*a)\n", "\n", "latexify.get_latex(solve)\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 58 }, "id": "8bYSWIngGF8E", "outputId": 
"669e070d-2718-49cb-a2fe-0defe0286b27" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle \\mathrm{sinc}(x) = \\left\\{ \\begin{array}{ll} 1, & \\mathrm{if} \\ x = 0 \\\\ \\frac{\\sin x}{x}, & \\mathrm{otherwise} \\end{array} \\right. $$" ], "text/plain": [ "" ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "@latexify.function\n", "def sinc(x):\n", " if x == 0:\n", " return 1\n", " else:\n", " return math.sin(x) / x\n", "\n", "sinc\n" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 78 }, "id": "h1i4BjdgHjxl", "outputId": "e448ff37-4753-4090-b2b1-1ef21b279b34" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle \\mathrm{fib}(x) = \\left\\{ \\begin{array}{ll} 0, & \\mathrm{if} \\ x = 0 \\\\ 1, & \\mathrm{if} \\ x = 1 \\\\ \\mathrm{fib} \\mathopen{}\\left( x - 1 \\mathclose{}\\right) + \\mathrm{fib} \\mathopen{}\\left( x - 2 \\mathclose{}\\right), & \\mathrm{otherwise} \\end{array} \\right. 
$$" ], "text/plain": [ "" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Elif or nested else-if are unrolled.\n", "@latexify.function\n", "def fib(x):\n", " if x == 0:\n", " return 0\n", " elif x == 1:\n", " return 1\n", " else:\n", " return fib(x-1) + fib(x-2)\n", "\n", "fib\n" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 39 }, "id": "-JhJMAXM7j-X", "outputId": "a47dcd59-2ff9-4aa1-935d-7c789b39057e" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle \\mathrm{greek}(\\alpha, \\beta, \\gamma, \\Omega) = \\alpha \\beta + \\Gamma \\mathopen{}\\left( \\gamma \\mathclose{}\\right) + \\Omega $$" ], "text/plain": [ "" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Some math symbols are converted automatically.\n", "@latexify.function(use_math_symbols=True)\n", "def greek(alpha, beta, gamma, Omega):\n", " return alpha * beta + math.gamma(gamma) + Omega\n", "\n", "greek\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 39 }, "id": "ySyNPS0y4tzu", "outputId": "2d95b5ce-a9b8-42b1-eb55-dc8bd0097d69" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle f(x) = g \\mathopen{}\\left( x \\mathclose{}\\right) $$" ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Function names, arguments, variables can be replaced.\n", "identifiers = {\n", " \"my_function\": \"f\",\n", " \"my_inner_function\": \"g\",\n", " \"my_argument\": \"x\",\n", "}\n", "\n", "@latexify.function(identifiers=identifiers)\n", "def my_function(my_argument):\n", " return my_inner_function(my_argument)\n", "\n", "my_function\n" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 56 }, "id": "TyacQaDM4Ei7", 
"outputId": "8e971bbd-2c74-45d2-d0fa-7f46569b10a6" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle f(a, b, c) = \\frac{-b + \\sqrt{ b^{2} - 4 a c }}{2 a} $$" ], "text/plain": [ "" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Assignments can be reduced into one expression.\n", "@latexify.function(reduce_assignments=True)\n", "def f(a, b, c):\n", " discriminant = b**2 - 4 * a * c\n", " numerator = -b + math.sqrt(discriminant)\n", " denominator = 2 * a\n", " return numerator / denominator\n", "\n", "f\n" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 78 }, "id": "oD8MFS2WE-2U", "outputId": "f9fad1bd-b7eb-41cc-8743-ec0d80cca8bc" }, "outputs": [ { "data": { "text/latex": [ "$$ \\displaystyle \\mathrm{transform}(x, y, a, b, \\theta, s, t) = \\begin{bmatrix} 1 & 0 & s \\\\ 0 & 1 & t \\\\ 0 & 0 & 1 \\end{bmatrix} \\cdot \\begin{bmatrix} \\cos \\theta & -\\sin \\theta & 0 \\\\ \\sin \\theta & \\cos \\theta & 0 \\\\ 0 & 0 & 1 \\end{bmatrix} \\cdot \\begin{bmatrix} a & 0 & 0 \\\\ 0 & b & 0 \\\\ 0 & 0 & 1 \\end{bmatrix} \\cdot \\begin{bmatrix} x \\\\ y \\\\ 1 \\end{bmatrix} $$" ], "text/plain": [ "" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Matrix support.\n", "@latexify.function(reduce_assignments=True, use_math_symbols=True)\n", "def transform(x, y, a, b, theta, s, t):\n", " cos_t = math.cos(theta)\n", " sin_t = math.sin(theta)\n", " scale = np.array([[a, 0, 0], [0, b, 0], [0, 0, 1]])\n", " rotate = np.array([[cos_t, -sin_t, 0], [sin_t, cos_t, 0], [0, 0, 1]])\n", " move = np.array([[1, 0, s], [0, 1, t], [0, 0, 1]])\n", " return move @ rotate @ scale @ np.array([[x], [y], [1]])\n", "\n", "transform\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 241 }, "id": "81OlPVWyGfWN", "outputId": 
"48660400-a812-41e2-91ea-23e49ea20c7f" }, "outputs": [ { "data": { "text/latex": [ "$ \\begin{array}{l} \\mathbf{function} \\ \\mathrm{fib}(x) \\\\ \\hspace{1em} \\mathbf{if} \\ x = 0 \\\\ \\hspace{2em} \\mathbf{return} \\ 0 \\\\ \\hspace{1em} \\mathbf{else} \\\\ \\hspace{2em} \\mathbf{if} \\ x = 1 \\\\ \\hspace{3em} \\mathbf{return} \\ 1 \\\\ \\hspace{2em} \\mathbf{else} \\\\ \\hspace{3em} \\mathbf{return} \\ \\mathrm{fib} \\mathopen{}\\left( x - 1 \\mathclose{}\\right) + \\mathrm{fib} \\mathopen{}\\left( x - 2 \\mathclose{}\\right) \\\\ \\hspace{2em} \\mathbf{end \\ if} \\\\ \\hspace{1em} \\mathbf{end \\ if} \\\\ \\mathbf{end \\ function} \\end{array} $" ], "text/plain": [ "" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# latexify.algorithmic generates an algorithmic environment instead of an equation.\n", "@latexify.algorithmic\n", "def fib(x):\n", " if x == 0:\n", " return 0\n", " elif x == 1:\n", " return 1\n", " else:\n", " return fib(x-1) + fib(x-2)\n", "\n", "fib\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 261 }, "id": "kbw_1txkGfnX", "outputId": "fdc58207-1c06-4d88-e249-b0b011bd98c0" }, "outputs": [ { "data": { "text/latex": [ "$ \\begin{array}{l} \\mathbf{function} \\ \\mathrm{collatz}(x) \\\\ \\hspace{1em} n \\gets 0 \\\\ \\hspace{1em} \\mathbf{while} \\ x > 1 \\\\ \\hspace{2em} n \\gets n + 1 \\\\ \\hspace{2em} \\mathbf{if} \\ x \\mathbin{\\%} 2 = 0 \\\\ \\hspace{3em} x \\gets \\left\\lfloor\\frac{x}{2}\\right\\rfloor \\\\ \\hspace{2em} \\mathbf{else} \\\\ \\hspace{3em} x \\gets 3 x + 1 \\\\ \\hspace{2em} \\mathbf{end \\ if} \\\\ \\hspace{1em} \\mathbf{end \\ while} \\\\ \\hspace{1em} \\mathbf{return} \\ n \\\\ \\mathbf{end \\ function} \\end{array} $" ], "text/plain": [ "" ] }, "execution_count": 13, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Another example: latexify.algorithmic supports usual 
control flows.\n", "@latexify.algorithmic\n", "def collatz(x):\n", " n = 0\n", " while x > 1:\n", " n = n + 1\n", " if x % 2 == 0:\n", " x = x // 2\n", " else:\n", " x = 3 * x + 1\n", " return n\n", "\n", "collatz\n" ] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 0 } ================================================ FILE: pyproject.toml ================================================ [build-system] requires = [ "hatchling", ] build-backend = "hatchling.build" [project] name = "latexify-py" description = "Generates LaTeX math description from Python functions." readme = "README.md" requires-python = ">=3.9, <3.14" license = {text = "Apache Software License 2.0"} authors = [ {name = "Yusuke Oda", email = "odashi@inspiredco.ai"} ] keywords = [ "equation", "latex", "math", "mathematics", "tex", ] classifiers = [ "Framework :: IPython", "Framework :: Jupyter", "License :: OSI Approved :: Apache Software License", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Topic :: Scientific/Engineering :: Mathematics", "Topic :: Software Development :: Code Generators", "Topic :: Text Processing :: Markup :: LaTeX", ] dependencies = [ "dill>=0.3.2", ] dynamic = [ "version" ] [project.optional-dependencies] dev = [ "build>=0.8", "black>=24.3", "flake8>=6.0", "isort>=5.10", "mypy>=1.9", "notebook>=6.5.1", "pyproject-flake8>=6.0", "pytest>=7.1", "twine>=4.0", ] mypy = [ "mypy>=1.9", "pytest>=7.1", ] [project.urls] Homepage = "https://github.com/google/latexify_py" "Bug Tracker" = "https://github.com/google/latexify_py/issues" [tool.hatch.build] include = [ "*.py", ] exclude = [ "*_test.py", ] only-packages = true [tool.hatch.build.targets.wheel] packages = ["src/latexify"] [tool.hatch.version] path = 
"src/latexify/_version.py" [tool.flake8] max-line-length = 88 extend-ignore = "E203" [tool.isort] profile = "black" ================================================ FILE: src/integration_tests/__init__.py ================================================ """Package integration_tests.""" import pytest pytest.register_assert_rewrite("integration_tests.utils") ================================================ FILE: src/integration_tests/algorithmic_style_test.py ================================================ """End-to-end test cases of algorithmic style.""" from __future__ import annotations import textwrap from integration_tests import integration_utils def test_factorial() -> None: def fact(n): if n == 0: return 1 else: return n * fact(n - 1) latex = textwrap.dedent( r""" \begin{algorithmic} \Function{fact}{$n$} \If{$n = 0$} \State \Return $1$ \Else \State \Return $n \cdot \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right)$ \EndIf \EndFunction \end{algorithmic} """ # noqa: E501 ).strip() ipython_latex = ( r"\begin{array}{l}" r" \mathbf{function} \ \mathrm{fact}(n) \\" r" \hspace{1em} \mathbf{if} \ n = 0 \\" r" \hspace{2em} \mathbf{return} \ 1 \\" r" \hspace{1em} \mathbf{else} \\" r" \hspace{2em}" r" \mathbf{return} \ n \cdot" r" \mathrm{fact} \mathopen{}\left( n - 1 \mathclose{}\right) \\" r" \hspace{1em} \mathbf{end \ if} \\" r" \mathbf{end \ function}" r" \end{array}" ) integration_utils.check_algorithm(fact, latex, ipython_latex) def test_collatz() -> None: def collatz(n): iterations = 0 while n > 1: if n % 2 == 0: n = n // 2 else: n = 3 * n + 1 iterations = iterations + 1 return iterations latex = textwrap.dedent( r""" \begin{algorithmic} \Function{collatz}{$n$} \State $\mathrm{iterations} \gets 0$ \While{$n > 1$} \If{$n \mathbin{\%} 2 = 0$} \State $n \gets \left\lfloor\frac{n}{2}\right\rfloor$ \Else \State $n \gets 3 n + 1$ \EndIf \State $\mathrm{iterations} \gets \mathrm{iterations} + 1$ \EndWhile \State \Return $\mathrm{iterations}$ \EndFunction 
\end{algorithmic} """ ).strip() ipython_latex = ( r"\begin{array}{l}" r" \mathbf{function} \ \mathrm{collatz}(n) \\" r" \hspace{1em} \mathrm{iterations} \gets 0 \\" r" \hspace{1em} \mathbf{while} \ n > 1 \\" r" \hspace{2em} \mathbf{if} \ n \mathbin{\%} 2 = 0 \\" r" \hspace{3em} n \gets \left\lfloor\frac{n}{2}\right\rfloor \\" r" \hspace{2em} \mathbf{else} \\" r" \hspace{3em} n \gets 3 n + 1 \\" r" \hspace{2em} \mathbf{end \ if} \\" r" \hspace{2em}" r" \mathrm{iterations} \gets \mathrm{iterations} + 1 \\" r" \hspace{1em} \mathbf{end \ while} \\" r" \hspace{1em} \mathbf{return} \ \mathrm{iterations} \\" r" \mathbf{end \ function}" r" \end{array}" ) integration_utils.check_algorithm(collatz, latex, ipython_latex) ================================================ FILE: src/integration_tests/function_expansion_test.py ================================================ """End-to-end test cases of function expansion.""" from __future__ import annotations import math from integration_tests import integration_utils def test_atan2() -> None: def solve(x, y): return math.atan2(y, x) latex = ( r"\mathrm{solve}(x, y) =" r" \arctan \mathopen{}\left( \frac{y}{x} \mathclose{}\right)" ) integration_utils.check_function(solve, latex, expand_functions={"atan2"}) def test_atan2_nested() -> None: def solve(x, y): return math.atan2(math.exp(y), math.exp(x)) latex = ( r"\mathrm{solve}(x, y) =" r" \arctan \mathopen{}\left( \frac{e^{y}}{e^{x}} \mathclose{}\right)" ) integration_utils.check_function(solve, latex, expand_functions={"atan2", "exp"}) def test_exp() -> None: def solve(x): return math.exp(x) latex = r"\mathrm{solve}(x) = e^{x}" integration_utils.check_function(solve, latex, expand_functions={"exp"}) def test_exp_nested() -> None: def solve(x): return math.exp(math.exp(x)) latex = r"\mathrm{solve}(x) = e^{e^{x}}" integration_utils.check_function(solve, latex, expand_functions={"exp"}) def test_exp2() -> None: def solve(x): return math.exp2(x) latex = r"\mathrm{solve}(x) = 2^{x}" 
integration_utils.check_function(solve, latex, expand_functions={"exp2"}) def test_exp2_nested() -> None: def solve(x): return math.exp2(math.exp2(x)) latex = r"\mathrm{solve}(x) = 2^{2^{x}}" integration_utils.check_function(solve, latex, expand_functions={"exp2"}) def test_expm1() -> None: def solve(x): return math.expm1(x) latex = r"\mathrm{solve}(x) = \exp x - 1" integration_utils.check_function(solve, latex, expand_functions={"expm1"}) def test_expm1_nested() -> None: def solve(x, y, z): return math.expm1(math.pow(y, z)) latex = r"\mathrm{solve}(x, y, z) = e^{y^{z}} - 1" integration_utils.check_function( solve, latex, expand_functions={"expm1", "exp", "pow"} ) def test_hypot_without_attribute() -> None: from math import hypot def solve(x, y, z): return hypot(x, y, z) latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }" integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_hypot() -> None: def solve(x, y, z): return math.hypot(x, y, z) latex = r"\mathrm{solve}(x, y, z) = \sqrt{ x^{2} + y^{2} + z^{2} }" integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_hypot_nested() -> None: def solve(a, b, x, y): return math.hypot(math.hypot(a, b), x, y) latex = ( r"\mathrm{solve}(a, b, x, y) =" r" \sqrt{ \sqrt{ a^{2} + b^{2} }^{2} + x^{2} + y^{2} }" ) integration_utils.check_function(solve, latex, expand_functions={"hypot"}) def test_log1p() -> None: def solve(x): return math.log1p(x) latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + x \mathclose{}\right)" integration_utils.check_function(solve, latex, expand_functions={"log1p"}) def test_log1p_nested() -> None: def solve(x): return math.log1p(math.exp(x)) latex = r"\mathrm{solve}(x) = \log \mathopen{}\left( 1 + e^{x} \mathclose{}\right)" integration_utils.check_function(solve, latex, expand_functions={"log1p", "exp"}) def test_pow_nested() -> None: def solve(w, x, y, z): return math.pow(math.pow(w, x), math.pow(y, z)) latex = ( 
r"\mathrm{solve}(w, x, y, z) = " r"\mathopen{}\left( w^{x} \mathclose{}\right)^{y^{z}}" ) integration_utils.check_function(solve, latex, expand_functions={"pow"}) def test_pow() -> None: def solve(x, y): return math.pow(x, y) latex = r"\mathrm{solve}(x, y) = x^{y}" integration_utils.check_function(solve, latex, expand_functions={"pow"}) ================================================ FILE: src/integration_tests/integration_utils.py ================================================ """Utilities for integration tests.""" from __future__ import annotations from typing import Any, Callable from latexify import frontend def check_function( fn: Callable[..., Any], latex: str, **kwargs, ) -> None: """Helper to check if the obtained function has the expected LaTeX form. Args: fn: Function to check. latex: LaTeX form of `fn`. **kwargs: Arguments passed to `frontend.function`. """ # Checks the syntax: # @function # def fn(...): # ... if not kwargs: latexified = frontend.function(fn) assert str(latexified) == latex assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$" # Checks the syntax: # @function(**kwargs) # def fn(...): # ... latexified = frontend.function(**kwargs)(fn) assert str(latexified) == latex assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$" # Checks the syntax: # def fn(...): # ... # latexified = function(fn, **kwargs) latexified = frontend.function(fn, **kwargs) assert str(latexified) == latex assert latexified._repr_latex_() == rf"$$ \displaystyle {latex} $$" def check_algorithm( fn: Callable[..., Any], latex: str, ipython_latex: str, **kwargs, ) -> None: """Helper to check if the obtained algorithm has the expected LaTeX forms. Checks both `str()` of the result (plain LaTeX) and its `_repr_latex_()` (the IPython display form wrapped in `$ ... $`) for every supported `frontend.algorithmic` call syntax. Args: fn: Function to check. latex: LaTeX form of `fn`. ipython_latex: IPython LaTeX form of `fn`. 
**kwargs: Arguments passed to `frontend.algorithmic`. """ # Checks the syntax: # @algorithmic # def fn(...): # ...
if not kwargs: latexified = frontend.algorithmic(fn) assert str(latexified) == latex assert latexified._repr_latex_() == f"$ {ipython_latex} $" # Checks the syntax: # @algorithmic(**kwargs) # def fn(...): # ... latexified = frontend.algorithmic(**kwargs)(fn) assert str(latexified) == latex assert latexified._repr_latex_() == f"$ {ipython_latex} $" # Checks the syntax: # def fn(...): # ... # latexified = algorithmic(fn, **kwargs) latexified = frontend.algorithmic(fn, **kwargs) assert str(latexified) == latex assert latexified._repr_latex_() == f"$ {ipython_latex} $" ================================================ FILE: src/integration_tests/regression_test.py ================================================ """End-to-end test cases of function.""" from __future__ import annotations import math from integration_tests import integration_utils def test_quadratic_solution() -> None: def solve(a, b, c): return (-b + math.sqrt(b**2 - 4 * a * c)) / (2 * a) latex = r"\mathrm{solve}(a, b, c) =" r" \frac{-b + \sqrt{ b^{2} - 4 a c }}{2 a}" integration_utils.check_function(solve, latex) def test_sinc() -> None: def sinc(x): if x == 0: return 1 else: return math.sin(x) / x latex = ( r"\mathrm{sinc}(x) =" r" \left\{ \begin{array}{ll}" r" 1, & \mathrm{if} \ x = 0 \\" r" \frac{\sin x}{x}, & \mathrm{otherwise}" r" \end{array} \right." 
) integration_utils.check_function(sinc, latex) def test_x_times_beta() -> None: def xtimesbeta(x, beta): return x * beta latex_without_symbols = ( r"\mathrm{xtimesbeta}(x, \mathrm{beta}) = x \cdot \mathrm{beta}" ) integration_utils.check_function(xtimesbeta, latex_without_symbols) integration_utils.check_function( xtimesbeta, latex_without_symbols, use_math_symbols=False ) latex_with_symbols = r"\mathrm{xtimesbeta}(x, \beta) = x \beta" integration_utils.check_function( xtimesbeta, latex_with_symbols, use_math_symbols=True ) def test_sum_with_limit_1arg() -> None: def sum_with_limit(n): return sum(i**2 for i in range(n)) latex = ( r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n - 1}" r" \mathopen{}\left({i^{2}}\mathclose{}\right)" ) integration_utils.check_function(sum_with_limit, latex) def test_sum_with_limit_2args() -> None: def sum_with_limit(a, n): return sum(i**2 for i in range(a, n)) latex = ( r"\mathrm{sum\_with\_limit}(a, n) = \sum_{i = a}^{n - 1}" r" \mathopen{}\left({i^{2}}\mathclose{}\right)" ) integration_utils.check_function(sum_with_limit, latex) def test_sum_with_reducible_limit() -> None: def sum_with_limit(n): return sum(i for i in range(n + 1)) latex = ( r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n}" r" \mathopen{}\left({i}\mathclose{}\right)" ) integration_utils.check_function(sum_with_limit, latex) def test_sum_with_irreducible_limit() -> None: def sum_with_limit(n): return sum(i for i in range(n * 3)) latex = ( r"\mathrm{sum\_with\_limit}(n) = \sum_{i = 0}^{n \cdot 3 - 1}" r" \mathopen{}\left({i}\mathclose{}\right)" ) integration_utils.check_function(sum_with_limit, latex) def test_prod_with_limit_1arg() -> None: def prod_with_limit(n): return math.prod(i**2 for i in range(n)) latex = ( r"\mathrm{prod\_with\_limit}(n) =" r" \prod_{i = 0}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)" ) integration_utils.check_function(prod_with_limit, latex) def test_prod_with_limit_2args() -> None: def prod_with_limit(a, n): return math.prod(i**2 
for i in range(a, n)) latex = ( r"\mathrm{prod\_with\_limit}(a, n) =" r" \prod_{i = a}^{n - 1} \mathopen{}\left({i^{2}}\mathclose{}\right)" ) integration_utils.check_function(prod_with_limit, latex) def test_prod_with_reducible_limits() -> None: def prod_with_limit(n): return math.prod(i for i in range(n - 1)) latex = ( r"\mathrm{prod\_with\_limit}(n) =" r" \prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)" ) integration_utils.check_function(prod_with_limit, latex) def test_prod_with_irreducible_limit() -> None: def prod_with_limit(n): return math.prod(i for i in range(n * 3)) latex = ( r"\mathrm{prod\_with\_limit}(n) = " r"\prod_{i = 0}^{n \cdot 3 - 1} \mathopen{}\left({i}\mathclose{}\right)" ) integration_utils.check_function(prod_with_limit, latex) def test_nested_function() -> None: def nested(x): return 3 * x integration_utils.check_function(nested, r"\mathrm{nested}(x) = 3 x") def test_double_nested_function() -> None: def nested(x): def inner(y): return x * y return inner integration_utils.check_function(nested(3), r"\mathrm{inner}(y) = x y") def test_reduce_assignments() -> None: def f(x): a = x + x return 3 * a integration_utils.check_function( f, r"\begin{array}{l} a = x + x \\ f(x) = 3 a \end{array}", ) integration_utils.check_function( f, r"f(x) = 3 \mathopen{}\left( x + x \mathclose{}\right)", reduce_assignments=True, ) def test_reduce_assignments_double() -> None: def f(x): a = x**2 b = a + a return 3 * b latex_without_option = ( r"\begin{array}{l}" r" a = x^{2} \\" r" b = a + a \\" r" f(x) = 3 b" r" \end{array}" ) integration_utils.check_function(f, latex_without_option) integration_utils.check_function(f, latex_without_option, reduce_assignments=False) integration_utils.check_function( f, r"f(x) = 3 \mathopen{}\left( x^{2} + x^{2} \mathclose{}\right)", reduce_assignments=True, ) def test_reduce_assignments_with_if() -> None: def sigmoid(x): p = 1 / (1 + math.exp(-x)) n = math.exp(x) / (math.exp(x) + 1) if x > 0: return p else: return n 
integration_utils.check_function( sigmoid, ( r"\mathrm{sigmoid}(x) = \left\{ \begin{array}{ll}" r" \frac{1}{1 + \exp \mathopen{}\left( -x \mathclose{}\right)}, &" r" \mathrm{if} \ x > 0 \\" r" \frac{\exp x}{\exp x + 1}, &" r" \mathrm{otherwise}" r" \end{array} \right." ), reduce_assignments=True, ) def test_sub_bracket() -> None: def solve(a, b): return ((a + b) - b) / (a - b) - (a + b) - (a - b) - (a * b) latex = ( r"\mathrm{solve}(a, b) =" r" \frac{a + b - b}{a - b} - \mathopen{}\left(" r" a + b \mathclose{}\right) - \mathopen{}\left(" r" a - b \mathclose{}\right) - a b" ) integration_utils.check_function(solve, latex) def test_docstring_allowed() -> None: def solve(x): """The identity function.""" return x latex = r"\mathrm{solve}(x) = x" integration_utils.check_function(solve, latex) def test_multiple_constants_allowed() -> None: def solve(x): """The identity function.""" 123 True return x latex = r"\mathrm{solve}(x) = x" integration_utils.check_function(solve, latex) ================================================ FILE: src/latexify/__init__.py ================================================ """Latexify root package.""" try: from latexify import _version __version__ = _version.__version__ except Exception: __version__ = "" from latexify import frontend, generate_latex Style = generate_latex.Style get_latex = generate_latex.get_latex algorithmic = frontend.algorithmic expression = frontend.expression function = frontend.function ================================================ FILE: src/latexify/_version.py ================================================ """Version specifier. DON'T TOUCH THIS FILE. This file is replaced during the release process. 
""" __version__ = "0.0.0a0" ================================================ FILE: src/latexify/analyzers.py ================================================ """Analyzer functions for specific subtrees.""" from __future__ import annotations import ast import dataclasses import sys from latexify import ast_utils, exceptions @dataclasses.dataclass(frozen=True, eq=False) class RangeInfo: """Information of the range function.""" # Argument subtrees. These arguments could be shallow copies of the original # subtree. start: ast.expr stop: ast.expr step: ast.expr # Integer representation of each argument, when it is possible. start_int: int | None stop_int: int | None step_int: int | None def analyze_range(node: ast.Call) -> RangeInfo: """Obtains RangeInfo from a Call subtree. Args: node: Subtree to be analyzed. Returns: RangeInfo extracted from `node`. Raises: LatexifySyntaxError: Analysis failed. """ if not ( isinstance(node.func, ast.Name) and node.func.id == "range" and 1 <= len(node.args) <= 3 ): raise exceptions.LatexifySyntaxError("Unsupported AST for analyze_range.") num_args = len(node.args) if num_args == 1: start = ast_utils.make_constant(0) stop = node.args[0] step = ast_utils.make_constant(1) else: start = node.args[0] stop = node.args[1] step = node.args[2] if num_args == 3 else ast_utils.make_constant(1) return RangeInfo( start=start, stop=stop, step=step, start_int=ast_utils.extract_int_or_none(start), stop_int=ast_utils.extract_int_or_none(stop), step_int=ast_utils.extract_int_or_none(step), ) def reduce_stop_parameter(node: ast.expr) -> ast.expr: """Adjusts the stop expression of the range. This function tries to convert the syntax as follows: * n + 1 --> n * n + 2 --> n + 1 * n - 1 --> n - 2 Args: node: The target expression. Returns: Converted expression. """ if not (isinstance(node, ast.BinOp) and isinstance(node.op, (ast.Add, ast.Sub))): return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1)) # Treatment for Python 3.7. 
rhs = ( ast.Constant(value=node.right.n) if sys.version_info.minor < 8 and isinstance(node.right, ast.Num) else node.right ) if not isinstance(rhs, ast.Constant): return ast.BinOp(left=node, op=ast.Sub(), right=ast_utils.make_constant(1)) shift = 1 if isinstance(node.op, ast.Add) else -1 return ( node.left if rhs.value == shift else ast.BinOp( left=node.left, op=node.op, right=ast_utils.make_constant(value=rhs.value - shift), ) ) ================================================ FILE: src/latexify/analyzers_test.py ================================================ """Tests for latexify.analyzers.""" from __future__ import annotations import ast import pytest from latexify import analyzers, ast_utils, exceptions, test_utils @pytest.mark.parametrize( "code,start,stop,step,start_int,stop_int,step_int", [ ( "range(x)", ast.Constant(value=0), ast.Name(id="x", ctx=ast.Load()), ast.Constant(value=1), 0, None, 1, ), ( "range(123)", ast.Constant(value=0), ast.Constant(value=123), ast.Constant(value=1), 0, 123, 1, ), ( "range(x, y)", ast.Name(id="x", ctx=ast.Load()), ast.Name(id="y", ctx=ast.Load()), ast.Constant(value=1), None, None, 1, ), ( "range(123, y)", ast.Constant(value=123), ast.Name(id="y", ctx=ast.Load()), ast.Constant(value=1), 123, None, 1, ), ( "range(x, 123)", ast.Name(id="x", ctx=ast.Load()), ast.Constant(value=123), ast.Constant(value=1), None, 123, 1, ), ( "range(x, y, z)", ast.Name(id="x", ctx=ast.Load()), ast.Name(id="y", ctx=ast.Load()), ast.Name(id="z", ctx=ast.Load()), None, None, None, ), ( "range(123, y, z)", ast.Constant(value=123), ast.Name(id="y", ctx=ast.Load()), ast.Name(id="z", ctx=ast.Load()), 123, None, None, ), ( "range(x, 123, z)", ast.Name(id="x", ctx=ast.Load()), ast.Constant(value=123), ast.Name(id="z", ctx=ast.Load()), None, 123, None, ), ( "range(x, y, 123)", ast.Name(id="x", ctx=ast.Load()), ast.Name(id="y", ctx=ast.Load()), ast.Constant(value=123), None, None, 123, ), ], ) def test_analyze_range( code: str, start: ast.expr, stop: 
ast.expr, step: ast.expr, start_int: int | None, stop_int: int | None, step_int: int | None, ) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Call) info = analyzers.analyze_range(node) test_utils.assert_ast_equal(observed=info.start, expected=start) test_utils.assert_ast_equal(observed=info.stop, expected=stop) if step is not None: test_utils.assert_ast_equal(observed=info.step, expected=step) else: assert info.step is None def check_int(observed: int | None, expected: int | None) -> None: if expected is not None: assert observed == expected else: assert observed is None check_int(observed=info.start_int, expected=start_int) check_int(observed=info.stop_int, expected=stop_int) check_int(observed=info.step_int, expected=step_int) @pytest.mark.parametrize( "code", [ # Not a direct call "__builtins__.range(x)", 'getattr(__builtins__, "range")(x)', # Unsupported functions "f(x)", "iter(range(x))", # Range with invalid arguments "range()", "range(x, y, z, w)", ], ) def test_analyze_range_invalid(code: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Call) with pytest.raises( exceptions.LatexifySyntaxError, match=r"^Unsupported AST for analyze_range\.$" ): analyzers.analyze_range(node) @pytest.mark.parametrize( "before,after", [ ("n + 1", "n"), ("n + 2", "n + 1"), ("n - (-1)", "n - (-1) - 1"), ("n - 1", "n - 2"), ("1 * 2", "1 * 2 - 1"), ], ) def test_reduce_stop_parameter(before: str, after: str) -> None: test_utils.assert_ast_equal( analyzers.reduce_stop_parameter(ast_utils.parse_expr(before)), ast_utils.parse_expr(after), ) ================================================ FILE: src/latexify/ast_utils.py ================================================ """Utilities to generate AST nodes.""" from __future__ import annotations import ast import sys from typing import Any def parse_expr(code: str) -> ast.expr: """Parses given Python expression. Args: code: Python expression to parse. 
Returns: ast.expr corresponding to `code`. """ return ast.parse(code, mode="eval").body def make_name(id: str) -> ast.Name: """Generates a new Name node. Args: id: Name of the node. Returns: Generated ast.Name. """ return ast.Name(id=id, ctx=ast.Load()) def make_attribute(value: ast.expr, attr: str): """Generates a new Attribute node. Args: value: Parent value. attr: Attribute name. Returns: Generated ast.Attribute. """ return ast.Attribute(value=value, attr=attr, ctx=ast.Load()) def make_constant(value: Any) -> ast.expr: """Generates a new Constant node. Args: value: Value of the node. Returns: Generated ast.Constant or its equivalent. Raises: ValueError: Unsupported value type. """ if ( value is None or value is ... or isinstance(value, (bool, int, float, complex, str, bytes)) ): return ast.Constant(value=value) raise ValueError(f"Unsupported type to generate Constant: {type(value).__name__}") def is_constant(node: ast.AST) -> bool: """Checks if the node is a constant. Args: node: The node to examine. Returns: True if the node is a constant, False otherwise. """ return isinstance(node, ast.Constant) def is_str(node: ast.AST) -> bool: """Checks if the node is a str constant. Args: node: The node to examine. Returns: True if the node is a str constant, False otherwise. """ if sys.version_info.minor < 8 and isinstance(node, ast.Str): return True return isinstance(node, ast.Constant) and isinstance(node.value, str) def extract_int_or_none(node: ast.expr) -> int | None: """Extracts int constant from the given Constant node. Args: node: ast.Constant or its equivalent representing an int value. Returns: Extracted int value, or None if extraction failed. """ if ( isinstance(node, ast.Constant) and isinstance(node.value, int) and not isinstance(node.value, bool) ): return node.value return None def extract_int(node: ast.expr) -> int: """Extracts int constant from the given Constant node. Args: node: ast.Constant or its equivalent representing an int value. 
Returns: Extracted int value. Raises: ValueError: Not a subtree containing an int value. """ value = extract_int_or_none(node) if value is None: raise ValueError(f"Unsupported node to extract int: {type(node).__name__}") return value def extract_function_name_or_none(node: ast.Call) -> str | None: """Extracts function name from the given Call node. Args: node: ast.Call. Returns: Extracted function name, or None if not found. """ if isinstance(node.func, ast.Name): return node.func.id if isinstance(node.func, ast.Attribute): return node.func.attr return None def create_function_def( name, args, body, decorator_list, returns=None, type_comment=None, type_params=None, lineno=None, col_offset=None, end_lineno=None, end_col_offset=None, ) -> ast.FunctionDef: """Creates a FunctionDef node. This function generates an `ast.FunctionDef` node, optionally removing the `type_params` keyword argument for Python versions below 3.12. Args: name: Name of the function. args: Arguments of the function. body: Body of the function. decorator_list: List of decorators. returns: Return type of the function. type_comment: Type comment of the function. type_params: Type parameters of the function. lineno: Line number of the function definition. col_offset: Column offset of the function definition. end_lineno: End line number of the function definition. end_col_offset: End column offset of the function definition. Returns: ast.FunctionDef: The generated FunctionDef node. 
""" if sys.version_info.minor < 12: return ast.FunctionDef( name=name, args=args, body=body, decorator_list=decorator_list, returns=returns, type_comment=type_comment, lineno=lineno, col_offset=col_offset, end_lineno=end_lineno, end_col_offset=end_col_offset, ) # type: ignore return ast.FunctionDef( name=name, args=args, body=body, decorator_list=decorator_list, returns=returns, type_comment=type_comment, type_params=type_params, lineno=lineno, col_offset=col_offset, end_lineno=end_lineno, end_col_offset=end_col_offset, ) # type: ignore ================================================ FILE: src/latexify/ast_utils_test.py ================================================ """Tests for latexify.ast_utils.""" from __future__ import annotations import ast import sys from typing import Any import pytest from latexify import ast_utils, test_utils def test_parse_expr() -> None: test_utils.assert_ast_equal( ast_utils.parse_expr("a + b"), ast.BinOp( left=ast_utils.make_name("a"), op=ast.Add(), right=ast_utils.make_name("b"), ), ) def test_make_name() -> None: test_utils.assert_ast_equal( ast_utils.make_name("foo"), ast.Name(id="foo", ctx=ast.Load()) ) def test_make_attribute() -> None: test_utils.assert_ast_equal( ast_utils.make_attribute(ast_utils.make_name("foo"), "bar"), ast.Attribute(ast.Name(id="foo", ctx=ast.Load()), attr="bar", ctx=ast.Load()), ) @pytest.mark.parametrize( "value,expected", [ (None, ast.Constant(value=None)), (False, ast.Constant(value=False)), (True, ast.Constant(value=True)), (..., ast.Constant(value=...)), (123, ast.Constant(value=123)), (4.5, ast.Constant(value=4.5)), (6 + 7j, ast.Constant(value=6 + 7j)), ("foo", ast.Constant(value="foo")), (b"bar", ast.Constant(value=b"bar")), ], ) def test_make_constant(value: Any, expected: ast.Constant) -> None: test_utils.assert_ast_equal( observed=ast_utils.make_constant(value), expected=expected, ) def test_make_constant_invalid() -> None: with pytest.raises(ValueError, match=r"^Unsupported type to 
generate"): ast_utils.make_constant(object()) @pytest.mark.parametrize( "value,expected", [ (ast.Constant(value="foo"), True), (ast.Expr(value=ast.Constant(value=123)), False), (ast.Global(names=["bar"]), False), ], ) def test_is_constant(value: ast.AST, expected: bool) -> None: assert ast_utils.is_constant(value) is expected @pytest.mark.parametrize( "value,expected", [ (ast.Constant(value=123), False), (ast.Constant(value="foo"), True), (ast.Expr(value=ast.Constant(value="foo")), False), (ast.Global(names=["foo"]), False), ], ) def test_is_str(value: ast.AST, expected: bool) -> None: assert ast_utils.is_str(value) is expected def test_extract_int_or_none() -> None: assert ast_utils.extract_int_or_none(ast_utils.make_constant(-123)) == -123 assert ast_utils.extract_int_or_none(ast_utils.make_constant(0)) == 0 assert ast_utils.extract_int_or_none(ast_utils.make_constant(123)) == 123 def test_extract_int_or_none_invalid() -> None: # Not a Constant node with int. assert ast_utils.extract_int_or_none(ast_utils.make_constant(None)) is None assert ast_utils.extract_int_or_none(ast_utils.make_constant(True)) is None assert ast_utils.extract_int_or_none(ast_utils.make_constant(...)) is None assert ast_utils.extract_int_or_none(ast_utils.make_constant(123.0)) is None assert ast_utils.extract_int_or_none(ast_utils.make_constant(4 + 5j)) is None assert ast_utils.extract_int_or_none(ast_utils.make_constant("123")) is None assert ast_utils.extract_int_or_none(ast_utils.make_constant(b"123")) is None def test_extract_int() -> None: assert ast_utils.extract_int(ast_utils.make_constant(-123)) == -123 assert ast_utils.extract_int(ast_utils.make_constant(0)) == 0 assert ast_utils.extract_int(ast_utils.make_constant(123)) == 123 def test_extract_int_invalid() -> None: # Not a Constant node with int. 
with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant(None)) with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant(True)) with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant(...)) with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant(123.0)) with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant(4 + 5j)) with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant("123")) with pytest.raises(ValueError, match=r"^Unsupported node to extract int"): ast_utils.extract_int(ast_utils.make_constant(b"123")) @pytest.mark.parametrize( "value,expected", [ ( ast.Call( func=ast.Name(id="hypot", ctx=ast.Load()), args=[], keywords=[], ), "hypot", ), ( ast.Call( func=ast.Attribute( value=ast.Name(id="math", ctx=ast.Load()), attr="hypot", ctx=ast.Load(), ), args=[], keywords=[], ), "hypot", ), ( ast.Call( func=ast.Call( func=ast.Name(id="foo", ctx=ast.Load()), args=[], keywords=[] ), args=[], keywords=[], ), None, ), ], ) def test_extract_function_name_or_none(value: ast.Call, expected: str | None) -> None: assert ast_utils.extract_function_name_or_none(value) == expected def test_create_function_def() -> None: expected_args = ast.arguments( posonlyargs=[], args=[ast.arg(arg="x")], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[], ) kwargs = { "name": "test_func", "args": expected_args, "body": [ast.Return(value=ast.Name(id="x", ctx=ast.Load()))], "decorator_list": [], "returns": None, "type_comment": None, "lineno": 1, "col_offset": 0, "end_lineno": 2, "end_col_offset": 0, } if sys.version_info.minor >= 12: kwargs["type_params"] = [] func_def = 
ast_utils.create_function_def(**kwargs) assert isinstance(func_def, ast.FunctionDef) assert func_def.name == "test_func" assert func_def.args.posonlyargs == expected_args.posonlyargs assert func_def.args.args == expected_args.args assert func_def.args.kwonlyargs == expected_args.kwonlyargs assert func_def.args.kw_defaults == expected_args.kw_defaults assert func_def.args.defaults == expected_args.defaults ================================================ FILE: src/latexify/codegen/__init__.py ================================================ """Package latexify.codegen.""" from latexify.codegen import algorithmic_codegen, expression_codegen, function_codegen AlgorithmicCodegen = algorithmic_codegen.AlgorithmicCodegen ExpressionCodegen = expression_codegen.ExpressionCodegen FunctionCodegen = function_codegen.FunctionCodegen IPythonAlgorithmicCodegen = algorithmic_codegen.IPythonAlgorithmicCodegen ================================================ FILE: src/latexify/codegen/algorithmic_codegen.py ================================================ """Codegen for single algorithms.""" from __future__ import annotations import ast import contextlib from collections.abc import Generator from latexify import exceptions from latexify.codegen import expression_codegen, identifier_converter class AlgorithmicCodegen(ast.NodeVisitor): """Codegen for single algorithms. This codegen works for Module with single FunctionDef node to generate a single LaTeX expression of the given algorithm. """ _SPACES_PER_INDENT = 4 _identifier_converter: identifier_converter.IdentifierConverter _indent_level: int def __init__( self, *, use_math_symbols: bool = False, use_set_symbols: bool = False, escape_underscores: bool = True, ) -> None: """Initializer. Args: use_math_symbols: Whether to convert identifiers with a math symbol surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). use_set_symbols: Whether to use set symbols or not. 
""" self._expression_codegen = expression_codegen.ExpressionCodegen( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols, escape_underscores=escape_underscores, ) self._identifier_converter = identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols, use_mathrm=False, escape_underscores=escape_underscores, ) self._indent_level = 0 def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( f"Unsupported AST: {type(node).__name__}" ) def visit_Assign(self, node: ast.Assign) -> str: """Visit an Assign node.""" operands: list[str] = [ self._expression_codegen.visit(target) for target in node.targets ] operands.append(self._expression_codegen.visit(node.value)) operands_latex = r" \gets ".join(operands) return self._add_indent(rf"\State ${operands_latex}$") def visit_Expr(self, node: ast.Expr) -> str: """Visit an Expr node.""" return self._add_indent( rf"\State ${self._expression_codegen.visit(node.value)}$" ) def visit_For(self, node: ast.For) -> str: """Visit a For node.""" if len(node.orelse) != 0: raise exceptions.LatexifyNotSupportedError( "For statement with the else clause is not supported" ) target_latex = self._expression_codegen.visit(node.target) iter_latex = self._expression_codegen.visit(node.iter) with self._increment_level(): body_latex = "\n".join(self.visit(stmt) for stmt in node.body) return ( self._add_indent(f"\\For{{${target_latex} \\in {iter_latex}$}}\n") + f"{body_latex}\n" + self._add_indent("\\EndFor") ) # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" name_latex = self._identifier_converter.convert(node.name)[0] # Arguments arg_strs = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args ] latex = self._add_indent("\\begin{algorithmic}\n") with self._increment_level(): latex += self._add_indent( f"\\Function{{{name_latex}}}{{${', '.join(arg_strs)}$}}\n" ) with 
self._increment_level(): # Body body_strs: list[str] = [self.visit(stmt) for stmt in node.body] body_latex = "\n".join(body_strs) latex += f"{body_latex}\n" latex += self._add_indent("\\EndFunction\n") return latex + self._add_indent(r"\end{algorithmic}") # TODO(ZibingZhang): support \ELSIF def visit_If(self, node: ast.If) -> str: """Visit an If node.""" cond_latex = self._expression_codegen.visit(node.test) with self._increment_level(): body_latex = "\n".join(self.visit(stmt) for stmt in node.body) latex = self._add_indent(f"\\If{{${cond_latex}$}}\n" + body_latex) if node.orelse: latex += "\n" + self._add_indent("\\Else\n") with self._increment_level(): latex += "\n".join(self.visit(stmt) for stmt in node.orelse) return f"{latex}\n" + self._add_indent(r"\EndIf") def visit_Module(self, node: ast.Module) -> str: """Visit a Module node.""" return self.visit(node.body[0]) def visit_Return(self, node: ast.Return) -> str: """Visit a Return node.""" return ( self._add_indent( rf"\State \Return ${self._expression_codegen.visit(node.value)}$" ) if node.value is not None else self._add_indent(r"\State \Return") ) def visit_While(self, node: ast.While) -> str: """Visit a While node.""" if node.orelse: raise exceptions.LatexifyNotSupportedError( "While statement with the else clause is not supported" ) cond_latex = self._expression_codegen.visit(node.test) with self._increment_level(): body_latex = "\n".join(self.visit(stmt) for stmt in node.body) return ( self._add_indent(f"\\While{{${cond_latex}$}}\n") + f"{body_latex}\n" + self._add_indent(r"\EndWhile") ) def visit_Pass(self, node: ast.Pass) -> str: """Visit a Pass node.""" return self._add_indent(r"\State $\mathbf{pass}$") def visit_Break(self, node: ast.Break) -> str: """Visit a Break node.""" return self._add_indent(r"\State $\mathbf{break}$") def visit_Continue(self, node: ast.Continue) -> str: """Visit a Continue node.""" return self._add_indent(r"\State $\mathbf{continue}$") @contextlib.contextmanager def 
_increment_level(self) -> Generator[None, None, None]: """Context manager controlling indent level.""" self._indent_level += 1 yield self._indent_level -= 1 def _add_indent(self, line: str) -> str: """Adds an indent before the line. Args: line: The line to add an indent to. """ return self._indent_level * self._SPACES_PER_INDENT * " " + line class IPythonAlgorithmicCodegen(ast.NodeVisitor): """Codegen for single algorithms targeting IPython. This codegen works for Module with single FunctionDef node to generate a single LaTeX expression of the given algorithm. """ _EM_PER_INDENT = 1 _LINE_BREAK = r" \\ " _identifier_converter: identifier_converter.IdentifierConverter _indent_level: int def __init__( self, *, use_math_symbols: bool = False, use_set_symbols: bool = False, escape_underscores: bool = True, ) -> None: """Initializer. Args: use_math_symbols: Whether to convert identifiers with a math symbol surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). use_set_symbols: Whether to use set symbols or not. 
""" self._expression_codegen = expression_codegen.ExpressionCodegen( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols, escape_underscores=escape_underscores, ) self._identifier_converter = identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols, escape_underscores=escape_underscores ) self._indent_level = 0 def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( f"Unsupported AST: {type(node).__name__}" ) def visit_Assign(self, node: ast.Assign) -> str: """Visit an Assign node.""" operands: list[str] = [ self._expression_codegen.visit(target) for target in node.targets ] operands.append(self._expression_codegen.visit(node.value)) operands_latex = r" \gets ".join(operands) return self._add_indent(operands_latex) def visit_Expr(self, node: ast.Expr) -> str: """Visit an Expr node.""" return self._add_indent(self._expression_codegen.visit(node.value)) def visit_For(self, node: ast.For) -> str: """Visit a For node.""" if len(node.orelse) != 0: raise exceptions.LatexifyNotSupportedError( "For statement with the else clause is not supported" ) target_latex = self._expression_codegen.visit(node.target) iter_latex = self._expression_codegen.visit(node.iter) with self._increment_level(): body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body) return ( self._add_indent(r"\mathbf{for}") + rf" \ {target_latex} \in {iter_latex} \ \mathbf{{do}}{self._LINE_BREAK}" + f"{body_latex}{self._LINE_BREAK}" + self._add_indent(r"\mathbf{end \ for}") ) # TODO(ZibingZhang): support nested functions def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" name_latex = self._identifier_converter.convert(node.name)[0] # Arguments args_latex = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args ] # Body with self._increment_level(): body_stmts_latex: list[str] = [self.visit(stmt) for stmt in node.body] body_latex = 
self._LINE_BREAK.join(body_stmts_latex) return ( r"\begin{array}{l} " + self._add_indent(r"\mathbf{function}") + rf" \ {name_latex}({', '.join(args_latex)})" + f"{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}" + self._add_indent(r"\mathbf{end \ function}") + r" \end{array}" ) # TODO(ZibingZhang): support \ELSIF def visit_If(self, node: ast.If) -> str: """Visit an If node.""" cond_latex = self._expression_codegen.visit(node.test) with self._increment_level(): body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body) latex = self._add_indent( rf"\mathbf{{if}} \ {cond_latex}{self._LINE_BREAK}{body_latex}" ) if node.orelse: latex += self._LINE_BREAK + self._add_indent(r"\mathbf{else} \\ ") with self._increment_level(): latex += self._LINE_BREAK.join(self.visit(stmt) for stmt in node.orelse) return latex + self._LINE_BREAK + self._add_indent(r"\mathbf{end \ if}") def visit_Module(self, node: ast.Module) -> str: """Visit a Module node.""" return self.visit(node.body[0]) def visit_Return(self, node: ast.Return) -> str: """Visit a Return node.""" return ( self._add_indent(r"\mathbf{return} \ ") + self._expression_codegen.visit(node.value) if node.value is not None else self._add_indent(r"\mathbf{return}") ) def visit_While(self, node: ast.While) -> str: """Visit a While node.""" if node.orelse: raise exceptions.LatexifyNotSupportedError( "While statement with the else clause is not supported" ) cond_latex = self._expression_codegen.visit(node.test) with self._increment_level(): body_latex = self._LINE_BREAK.join(self.visit(stmt) for stmt in node.body) return ( self._add_indent(r"\mathbf{while} \ ") + f"{cond_latex}{self._LINE_BREAK}{body_latex}{self._LINE_BREAK}" + self._add_indent(r"\mathbf{end \ while}") ) def visit_Pass(self, node: ast.Pass) -> str: """Visit a Pass node.""" return self._add_indent(r"\mathbf{pass}") def visit_Break(self, node: ast.Break) -> str: """Visit a Break node.""" return self._add_indent(r"\mathbf{break}") def 
visit_Continue(self, node: ast.Continue) -> str: """Visit a Continue node.""" return self._add_indent(r"\mathbf{continue}") @contextlib.contextmanager def _increment_level(self) -> Generator[None, None, None]: """Context manager controlling indent level.""" self._indent_level += 1 yield self._indent_level -= 1 def _add_indent(self, line: str) -> str: """Adds an indent before the line. Args: line: The line to add an indent to. """ return ( rf"\hspace{{{self._indent_level * self._EM_PER_INDENT}em}} {line}" if self._indent_level > 0 else line ) ================================================ FILE: src/latexify/codegen/algorithmic_codegen_test.py ================================================ """Tests for latexify.codegen.algorithmic_codegen.""" from __future__ import annotations import ast import textwrap import pytest from latexify import exceptions from latexify.codegen import algorithmic_codegen def test_generic_visit() -> None: class UnknownNode(ast.AST): pass with pytest.raises( exceptions.LatexifyNotSupportedError, match=r"^Unsupported AST: UnknownNode$", ): algorithmic_codegen.AlgorithmicCodegen().visit(UnknownNode()) @pytest.mark.parametrize( "code,latex", [ ( "x = 3", r"\State $x \gets 3$", ), ( "a = b = 0", r"\State $a \gets b \gets 0$", ), ], ) def test_visit_assign(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.Assign) assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ( "for i in {1}: x = i", r""" \For{$i \in \mathopen{}\left\{ 1 \mathclose{}\right\}$} \State $x \gets i$ \EndFor """, ), ], ) def test_visit_for(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.For) assert ( algorithmic_codegen.AlgorithmicCodegen().visit(node) == textwrap.dedent(latex).strip() ) @pytest.mark.parametrize( "code,latex", [ ( "def f(x): return x", r""" \begin{algorithmic} \Function{f}{$x$} \State 
\Return $x$ \EndFunction \end{algorithmic} """, ), ( "def xyz(a, b, c): return 3", r""" \begin{algorithmic} \Function{xyz}{$a, b, c$} \State \Return $3$ \EndFunction \end{algorithmic} """, ), ], ) def test_visit_functiondef(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.FunctionDef) assert ( algorithmic_codegen.AlgorithmicCodegen().visit(node) == textwrap.dedent(latex).strip() ) @pytest.mark.parametrize( "code,latex", [ ( "if x < y: return x", r""" \If{$x < y$} \State \Return $x$ \EndIf """, ), ( "if True: x\nelse: y", r""" \If{$\mathrm{True}$} \State $x$ \Else \State $y$ \EndIf """, ), ], ) def test_visit_if(code: str, latex: str) -> None: node = ast.parse(code).body[0] assert isinstance(node, ast.If) assert ( algorithmic_codegen.AlgorithmicCodegen().visit(node) == textwrap.dedent(latex).strip() ) @pytest.mark.parametrize( "code,latex", [ ( "return x + y", r"\State \Return $x + y$", ), ( "return", r"\State \Return", ), ], ) def test_visit_return(code: str, latex: str) -> None: node = ast.parse(code).body[0] assert isinstance(node, ast.Return) assert algorithmic_codegen.AlgorithmicCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ( "while x < y: x = x + 1", r""" \While{$x < y$} \State $x \gets x + 1$ \EndWhile """, ) ], ) def test_visit_while(code: str, latex: str) -> None: node = ast.parse(code).body[0] assert isinstance(node, ast.While) assert ( algorithmic_codegen.AlgorithmicCodegen().visit(node) == textwrap.dedent(latex).strip() ) def test_visit_while_with_else() -> None: node = ast.parse( textwrap.dedent( """ while True: x = x else: x = y """ ) ).body[0] assert isinstance(node, ast.While) with pytest.raises( exceptions.LatexifyNotSupportedError, match="^While statement with the else clause is not supported$", ): algorithmic_codegen.AlgorithmicCodegen().visit(node) def test_visit_pass() -> None: node = ast.parse("pass").body[0] assert isinstance(node, ast.Pass) assert ( 
algorithmic_codegen.AlgorithmicCodegen().visit(node) == r"\State $\mathbf{pass}$" ) def test_visit_break() -> None: node = ast.parse("break").body[0] assert isinstance(node, ast.Break) assert ( algorithmic_codegen.AlgorithmicCodegen().visit(node) == r"\State $\mathbf{break}$" ) def test_visit_continue() -> None: node = ast.parse("continue").body[0] assert isinstance(node, ast.Continue) assert ( algorithmic_codegen.AlgorithmicCodegen().visit(node) == r"\State $\mathbf{continue}$" ) @pytest.mark.parametrize( "code,latex", [ ("x = 3", r"x \gets 3"), ("a = b = 0", r"a \gets b \gets 0"), ], ) def test_visit_assign_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.Assign) assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ( "for i in {1}: x = i", ( r"\mathbf{for} \ i \in \mathopen{}\left\{ 1 \mathclose{}\right\}" r" \ \mathbf{do} \\" r" \hspace{1em} x \gets i \\" r" \mathbf{end \ for}" ), ), ], ) def test_visit_for_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.For) assert ( algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == textwrap.dedent(latex).strip() ) @pytest.mark.parametrize( "code,latex", [ ( "def f(x): return x", ( r"\begin{array}{l}" r" \mathbf{function}" r" \ f(x) \\" r" \hspace{1em} \mathbf{return} \ x \\" r" \mathbf{end \ function}" r" \end{array}" ), ), ( "def f(a, b, c): return 3", ( r"\begin{array}{l}" r" \mathbf{function}" r" \ f(a, b, c) \\" r" \hspace{1em} \mathbf{return} \ 3 \\" r" \mathbf{end \ function}" r" \end{array}" ), ), ], ) def test_visit_functiondef_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.FunctionDef) assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ( "if x < y: return x", ( r"\mathbf{if} 
\ x < y \\" r" \hspace{1em} \mathbf{return} \ x \\" r" \mathbf{end \ if}" ), ), ( "if True: x\nelse: y", ( r"\mathbf{if} \ \mathrm{True} \\" r" \hspace{1em} x \\" r" \mathbf{else} \\" r" \hspace{1em} y \\" r" \mathbf{end \ if}" ), ), ], ) def test_visit_if_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.If) assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ( "return x + y", r"\mathbf{return} \ x + y", ), ( "return", r"\mathbf{return}", ), ], ) def test_visit_return_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.Return) assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ( "while x < y: x = x + 1", ( r"\mathbf{while} \ x < y \\" r" \hspace{1em} x \gets x + 1 \\" r" \mathbf{end \ while}" ), ) ], ) def test_visit_while_ipython(code: str, latex: str) -> None: node = ast.parse(textwrap.dedent(code)).body[0] assert isinstance(node, ast.While) assert algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == latex def test_visit_while_with_else_ipython() -> None: node = ast.parse( textwrap.dedent( """ while True: x = x else: x = y """ ) ).body[0] assert isinstance(node, ast.While) with pytest.raises( exceptions.LatexifyNotSupportedError, match="^While statement with the else clause is not supported$", ): algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) def test_visit_pass_ipython() -> None: node = ast.parse("pass").body[0] assert isinstance(node, ast.Pass) assert ( algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == r"\mathbf{pass}" ) def test_visit_break_ipython() -> None: node = ast.parse("break").body[0] assert isinstance(node, ast.Break) assert ( algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == r"\mathbf{break}" ) def test_visit_continue_ipython() -> None: node = 
ast.parse("continue").body[0] assert isinstance(node, ast.Continue) assert ( algorithmic_codegen.IPythonAlgorithmicCodegen().visit(node) == r"\mathbf{continue}" ) ================================================ FILE: src/latexify/codegen/codegen_utils.py ================================================ from typing import Any from latexify import exceptions def convert_constant(value: Any) -> str: """Helper to convert constant values to LaTeX. Args: value: A constant value. Returns: The LaTeX representation of `value`. """ if value is None or isinstance(value, bool): return r"\mathrm{" + str(value) + "}" if isinstance(value, (int, float, complex)): # TODO(odashi): Support other symbols for the imaginary unit than j. return str(value) if isinstance(value, str): return r'\textrm{"' + value + '"}' if isinstance(value, bytes): return r"\textrm{" + str(value) + "}" if value is ...: return r"\cdots" raise exceptions.LatexifyNotSupportedError( f"Unrecognized constant: {type(value).__name__}" ) ================================================ FILE: src/latexify/codegen/codegen_utils_test.py ================================================ """Tests for latexify.codegen.codegen_utils.""" from __future__ import annotations from typing import Any import pytest from latexify import exceptions from latexify.codegen.codegen_utils import convert_constant @pytest.mark.parametrize( "constant,latex", [ (None, r"\mathrm{None}"), (True, r"\mathrm{True}"), (False, r"\mathrm{False}"), (123, "123"), (456.789, "456.789"), (-3 + 4j, "(-3+4j)"), ("string", r'\textrm{"string"}'), (..., r"\cdots"), ], ) def test_convert_constant(constant: Any, latex: str) -> None: assert convert_constant(constant) == latex def test_convert_constant_unsupported_constant() -> None: with pytest.raises( exceptions.LatexifyNotSupportedError, match="^Unrecognized constant: " ): convert_constant({}) ================================================ FILE: src/latexify/codegen/expression_codegen.py 
================================================ """Codegen for single expressions.""" from __future__ import annotations import ast import re from latexify import analyzers, ast_utils, exceptions from latexify.codegen import codegen_utils, expression_rules, identifier_converter class ExpressionCodegen(ast.NodeVisitor): """Codegen for single expressions.""" _identifier_converter: identifier_converter.IdentifierConverter _bin_op_rules: dict[type[ast.operator], expression_rules.BinOpRule] _compare_ops: dict[type[ast.cmpop], str] def __init__( self, *, use_math_symbols: bool = False, use_set_symbols: bool = False, escape_underscores: bool = True, ) -> None: """Initializer. Args: use_math_symbols: Whether to convert identifiers with a math symbol surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). use_set_symbols: Whether to use set symbols or not. """ self._identifier_converter = identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols, escape_underscores=escape_underscores ) self._bin_op_rules = ( expression_rules.SET_BIN_OP_RULES if use_set_symbols else expression_rules.BIN_OP_RULES ) self._compare_ops = ( expression_rules.SET_COMPARE_OPS if use_set_symbols else expression_rules.COMPARE_OPS ) def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( f"Unsupported AST: {type(node).__name__}" ) def visit_Tuple(self, node: ast.Tuple) -> str: """Visit a Tuple node.""" elts = [self.visit(elt) for elt in node.elts] return r"\mathopen{}\left( " + r", ".join(elts) + r" \mathclose{}\right)" def visit_List(self, node: ast.List) -> str: """Visit a List node.""" elts = [self.visit(elt) for elt in node.elts] return r"\mathopen{}\left[ " + r", ".join(elts) + r" \mathclose{}\right]" def visit_Set(self, node: ast.Set) -> str: """Visit a Set node.""" elts = [self.visit(elt) for elt in node.elts] return r"\mathopen{}\left\{ " + r", ".join(elts) + r" \mathclose{}\right\}" def visit_ListComp(self, node: ast.ListComp) -> 
str: """Visit a ListComp node.""" generators = [self.visit(comp) for comp in node.generators] return ( r"\mathopen{}\left[ " + self.visit(node.elt) + r" \mid " + ", ".join(generators) + r" \mathclose{}\right]" ) def visit_SetComp(self, node: ast.SetComp) -> str: """Visit a SetComp node.""" generators = [self.visit(comp) for comp in node.generators] return ( r"\mathopen{}\left\{ " + self.visit(node.elt) + r" \mid " + ", ".join(generators) + r" \mathclose{}\right\}" ) def visit_comprehension(self, node: ast.comprehension) -> str: """Visit a comprehension node.""" target = rf"{self.visit(node.target)} \in {self.visit(node.iter)}" if not node.ifs: # Returns the source without parenthesis. return target conds = [target] + [self.visit(cond) for cond in node.ifs] wrapped = [r"\mathopen{}\left( " + s + r" \mathclose{}\right)" for s in conds] return r" \land ".join(wrapped) def _generate_sum_prod(self, node: ast.Call) -> str | None: """Generates sum/prod expression. Args: node: ast.Call node containing the sum/prod invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. """ if not node.args or not isinstance(node.args[0], ast.GeneratorExp): return None name = ast_utils.extract_function_name_or_none(node) assert name in ("fsum", "sum", "prod") command = { "fsum": r"\sum", "sum": r"\sum", "prod": r"\prod", }[name] elt, scripts = self._get_sum_prod_info(node.args[0]) scripts_str = [rf"{command}_{{{lo}}}^{{{up}}}" for lo, up in scripts] return ( " ".join(scripts_str) + rf" \mathopen{{}}\left({{{elt}}}\mathclose{{}}\right)" ) def _generate_matrix(self, node: ast.Call) -> str | None: """Generates matrix expression. Args: node: ast.Call node containing the ndarray invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. 
""" def generate_matrix_from_array(data: list[list[str]]) -> str: """Helper to generate a bmatrix environment.""" contents = r" \\ ".join(" & ".join(row) for row in data) return r"\begin{bmatrix} " + contents + r" \end{bmatrix}" arg = node.args[0] if not isinstance(arg, ast.List) or not arg.elts: # Not an array or no rows return None row0 = arg.elts[0] if not isinstance(row0, ast.List): # Maybe 1 x N array return generate_matrix_from_array([[self.visit(x) for x in arg.elts]]) if not row0.elts: # No columns return None ncols = len(row0.elts) rows: list[list[str]] = [] for row in arg.elts: if not isinstance(row, ast.List) or len(row.elts) != ncols: # Length mismatch return None rows.append([self.visit(x) for x in row.elts]) return generate_matrix_from_array(rows) def _generate_zeros(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.zeros. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. """ name = ast_utils.extract_function_name_or_none(node) assert name == "zeros" if len(node.args) != 1: return None # All args to np.zeros should be numeric. if isinstance(node.args[0], ast.Tuple): dims = [ast_utils.extract_int_or_none(x) for x in node.args[0].elts] if any(x is None for x in dims): return None if not dims: return "0" if len(dims) == 1: dims = [1, dims[0]] dims_latex = r" \times ".join(str(x) for x in dims) else: dim = ast_utils.extract_int_or_none(node.args[0]) if not isinstance(dim, int): return None # 1 x N array of zeros dims_latex = rf"1 \times {dim}" return rf"\mathbf{{0}}^{{{dims_latex}}}" def _generate_identity(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.identity. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. 
""" name = ast_utils.extract_function_name_or_none(node) assert name == "identity" if len(node.args) != 1: return None ndims = ast_utils.extract_int_or_none(node.args[0]) if ndims is None: return None return rf"\mathbf{{I}}_{{{ndims}}}" def _generate_transpose(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.transpose. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. Raises: LatexifyError: Unsupported argument type given. """ name = ast_utils.extract_function_name_or_none(node) assert name == "transpose" if len(node.args) != 1: return None func_arg = node.args[0] if isinstance(func_arg, ast.Name): return rf"\mathbf{{{func_arg.id}}}^\intercal" else: return None def _generate_determinant(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.linalg.det. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. Raises: LatexifyError: Unsupported argument type given. """ name = ast_utils.extract_function_name_or_none(node) assert name == "det" if len(node.args) != 1: return None func_arg = node.args[0] if isinstance(func_arg, ast.Name): arg_id = rf"\mathbf{{{func_arg.id}}}" return rf"\det \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)" elif isinstance(func_arg, ast.List): matrix = self._generate_matrix(node) return rf"\det \mathopen{{}}\left( {matrix} \mathclose{{}}\right)" return None def _generate_matrix_rank(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.linalg.matrix_rank. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. Raises: LatexifyError: Unsupported argument type given. 
""" name = ast_utils.extract_function_name_or_none(node) assert name == "matrix_rank" if len(node.args) != 1: return None func_arg = node.args[0] if isinstance(func_arg, ast.Name): arg_id = rf"\mathbf{{{func_arg.id}}}" return ( rf"\mathrm{{rank}} \mathopen{{}}\left( {arg_id} \mathclose{{}}\right)" ) elif isinstance(func_arg, ast.List): matrix = self._generate_matrix(node) return ( rf"\mathrm{{rank}} \mathopen{{}}\left( {matrix} \mathclose{{}}\right)" ) return None def _generate_matrix_power(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.linalg.matrix_power. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. Raises: LatexifyError: Unsupported argument type given. """ name = ast_utils.extract_function_name_or_none(node) assert name == "matrix_power" if len(node.args) != 2: return None func_arg = node.args[0] power_arg = node.args[1] if isinstance(power_arg, ast.Num): if isinstance(func_arg, ast.Name): return rf"\mathbf{{{func_arg.id}}}^{{{power_arg.n}}}" elif isinstance(func_arg, ast.List): matrix = self._generate_matrix(node) if matrix is not None: return rf"{matrix}^{{{power_arg.n}}}" return None def _generate_inv(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.linalg.inv. Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. Raises: LatexifyError: Unsupported argument type given. """ name = ast_utils.extract_function_name_or_none(node) assert name == "inv" if len(node.args) != 1: return None func_arg = node.args[0] if isinstance(func_arg, ast.Name): return rf"\mathbf{{{func_arg.id}}}^{{-1}}" elif isinstance(func_arg, ast.List): return rf"{self._generate_matrix(node)}^{{-1}}" return None def _generate_pinv(self, node: ast.Call) -> str | None: """Generates LaTeX for numpy.linalg.pinv. 
Args: node: ast.Call node containing the appropriate method invocation. Returns: Generated LaTeX, or None if the node has unsupported syntax. Raises: LatexifyError: Unsupported argument type given. """ name = ast_utils.extract_function_name_or_none(node) assert name == "pinv" if len(node.args) != 1: return None func_arg = node.args[0] if isinstance(func_arg, ast.Name): return rf"\mathbf{{{func_arg.id}}}^{{+}}" elif isinstance(func_arg, ast.List): return rf"{self._generate_matrix(node)}^{{+}}" return None def visit_Call(self, node: ast.Call) -> str: """Visit a Call node.""" func_name = ast_utils.extract_function_name_or_none(node) # Special treatments for some functions. # TODO(odashi): Move these functions to some separate utility. if func_name in ("fsum", "sum", "prod"): special_latex = self._generate_sum_prod(node) elif func_name in ("array", "ndarray"): special_latex = self._generate_matrix(node) elif func_name == "zeros": special_latex = self._generate_zeros(node) elif func_name == "identity": special_latex = self._generate_identity(node) elif func_name == "transpose": special_latex = self._generate_transpose(node) elif func_name == "det": special_latex = self._generate_determinant(node) elif func_name == "matrix_rank": special_latex = self._generate_matrix_rank(node) elif func_name == "matrix_power": special_latex = self._generate_matrix_power(node) elif func_name == "inv": special_latex = self._generate_inv(node) elif func_name == "pinv": special_latex = self._generate_pinv(node) else: special_latex = None if special_latex is not None: return special_latex # Obtains the codegen rule. rule = ( expression_rules.BUILTIN_FUNCS.get(func_name) if func_name is not None else None ) if rule is None: rule = expression_rules.FunctionRule(self.visit(node.func)) if rule.is_unary and len(node.args) == 1: # Unary function. Applies the same wrapping policy with the unary operators. 
precedence = expression_rules.get_precedence(node) arg = node.args[0] # NOTE(odashi): # Factorial "x!" is treated as a special case: it requires both inner/outer # parentheses for correct interpretation. force_wrap_factorial = isinstance(arg, ast.Call) and ( func_name == "factorial" or ast_utils.extract_function_name_or_none(arg) == "factorial" ) # Note(odashi): # Wrapping is also required if the argument is pow. # https://github.com/google/latexify_py/issues/189 force_wrap_pow = isinstance(arg, ast.BinOp) and isinstance(arg.op, ast.Pow) arg_latex = self._wrap_operand( arg, precedence, force_wrap_factorial or force_wrap_pow ) elements = [rule.left, arg_latex, rule.right] else: arg_latex = ", ".join(self.visit(arg) for arg in node.args) if rule.is_wrapped: elements = [rule.left, arg_latex, rule.right] else: elements = [ rule.left, r"\mathopen{}\left(", arg_latex, r"\mathclose{}\right)", rule.right, ] return " ".join(x for x in elements if x) def visit_Attribute(self, node: ast.Attribute) -> str: """Visit an Attribute node.""" vstr = self.visit(node.value) astr = self._identifier_converter.convert(node.attr)[0] return vstr + "." 
+ astr def visit_Name(self, node: ast.Name) -> str: """Visit a Name node.""" return self._identifier_converter.convert(node.id)[0] # From Python 3.8 def visit_Constant(self, node: ast.Constant) -> str: """Visit a Constant node.""" return codegen_utils.convert_constant(node.value) # Until Python 3.7 def visit_Num(self, node: ast.Num) -> str: """Visit a Num node.""" return codegen_utils.convert_constant(node.n) # Until Python 3.7 def visit_Str(self, node: ast.Str) -> str: """Visit a Str node.""" return codegen_utils.convert_constant(node.s) # Until Python 3.7 def visit_Bytes(self, node: ast.Bytes) -> str: """Visit a Bytes node.""" return codegen_utils.convert_constant(node.s) # Until Python 3.7 def visit_NameConstant(self, node: ast.NameConstant) -> str: """Visit a NameConstant node.""" return codegen_utils.convert_constant(node.value) # Until Python 3.7 def visit_Ellipsis(self, node: ast.Ellipsis) -> str: """Visit an Ellipsis node.""" return codegen_utils.convert_constant(...) def _wrap_operand( self, child: ast.expr, parent_prec: int, force_wrap: bool = False ) -> str: """Wraps the operand subtree with parentheses. Args: child: Operand subtree. parent_prec: Precedence of the parent operator. force_wrap: Whether to wrap the operand or not when the precedence is equal. Returns: LaTeX form of `child`, with or without surrounding parentheses. """ latex = self.visit(child) child_prec = expression_rules.get_precedence(child) if force_wrap or child_prec < parent_prec: return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" return latex def _wrap_binop_operand( self, child: ast.expr, parent_prec: int, operand_rule: expression_rules.BinOperandRule, ) -> str: """Wraps the operand subtree of BinOp with parentheses. Args: child: Operand subtree. parent_prec: Precedence of the parent operator. operand_rule: Syntax rule of this operand. Returns: LaTeX form of the `child`, with or without surrounding parentheses. 
""" if not operand_rule.wrap: return self.visit(child) if isinstance(child, ast.Call): child_fn_name = ast_utils.extract_function_name_or_none(child) rule = ( expression_rules.BUILTIN_FUNCS.get(child_fn_name) if child_fn_name is not None else None ) if rule is not None and rule.is_wrapped: return self.visit(child) if not isinstance(child, ast.BinOp): return self._wrap_operand(child, parent_prec) latex = self.visit(child) if expression_rules.BIN_OP_RULES[type(child.op)].is_wrapped: return latex child_prec = expression_rules.get_precedence(child) if child_prec > parent_prec or ( child_prec == parent_prec and not operand_rule.force ): return latex return rf"\mathopen{{}}\left( {latex} \mathclose{{}}\right)" _l_bracket_pattern = re.compile(r"^\\mathopen.*") _r_bracket_pattern = re.compile(r".*\\mathclose[^ ]+$") _r_word_pattern = re.compile(r"\\mathrm\{[^ ]+\}$") def _should_remove_multiply_op( self, l_latex: str, r_latex: str, l_expr: ast.expr, r_expr: ast.expr ): """Determine whether the multiply operator should be removed or not. See also: https://github.com/google/latexify_py/issues/89#issuecomment-1344967636 This is an ad-hoc implementation. This function doesn't fully implements the above requirements, but only essential ones necessary to release v0.3. """ # NOTE(odashi): For compatibility with Python 3.7, we compare the generated # caracter type directly to determine the "numeric" type. 
if isinstance(l_expr, ast.Call): l_type = "f" elif self._r_bracket_pattern.match(l_latex): l_type = "b" elif self._r_word_pattern.match(l_latex): l_type = "w" elif l_latex[-1].isnumeric(): l_type = "n" else: le = l_expr while True: if isinstance(le, ast.UnaryOp): le = le.operand elif isinstance(le, ast.BinOp): le = le.right elif isinstance(le, ast.Compare): le = le.comparators[-1] elif isinstance(le, ast.BoolOp): le = le.values[-1] else: break l_type = "a" if isinstance(le, ast.Name) and len(le.id) == 1 else "m" if isinstance(r_expr, ast.Call): r_type = "f" elif self._l_bracket_pattern.match(r_latex): r_type = "b" elif r_latex.startswith("\\mathrm"): r_type = "w" elif r_latex[0].isnumeric(): r_type = "n" else: re = r_expr while True: if isinstance(re, ast.UnaryOp): if isinstance(re.op, ast.USub): # NOTE(odashi): Unary "-" always require \cdot. return False re = re.operand elif isinstance(re, ast.BinOp): re = re.left elif isinstance(re, ast.Compare): re = re.left elif isinstance(re, ast.BoolOp): re = re.values[0] else: break r_type = "a" if isinstance(re, ast.Name) and len(re.id) == 1 else "m" if r_type == "n": return False if l_type in "bn": return True if l_type in "am" and r_type in "am": return True return False def visit_BinOp(self, node: ast.BinOp) -> str: """Visit a BinOp node.""" prec = expression_rules.get_precedence(node) rule = self._bin_op_rules[type(node.op)] lhs = self._wrap_binop_operand(node.left, prec, rule.operand_left) rhs = self._wrap_binop_operand(node.right, prec, rule.operand_right) if type(node.op) in [ast.Mult, ast.MatMult]: if self._should_remove_multiply_op(lhs, rhs, node.left, node.right): return f"{rule.latex_left}{lhs} {rhs}{rule.latex_right}" return f"{rule.latex_left}{lhs}{rule.latex_middle}{rhs}{rule.latex_right}" def visit_UnaryOp(self, node: ast.UnaryOp) -> str: """Visit a UnaryOp node.""" latex = self._wrap_operand(node.operand, expression_rules.get_precedence(node)) return expression_rules.UNARY_OPS[type(node.op)] + latex def 
visit_Compare(self, node: ast.Compare) -> str: """Visit a Compare node.""" parent_prec = expression_rules.get_precedence(node) lhs = self._wrap_operand(node.left, parent_prec) ops = [self._compare_ops[type(x)] for x in node.ops] rhs = [self._wrap_operand(x, parent_prec) for x in node.comparators] ops_rhs = [f" {o} {r}" for o, r in zip(ops, rhs)] return lhs + "".join(ops_rhs) def visit_BoolOp(self, node: ast.BoolOp) -> str: """Visit a BoolOp node.""" parent_prec = expression_rules.get_precedence(node) values = [self._wrap_operand(x, parent_prec) for x in node.values] op = f" {expression_rules.BOOL_OPS[type(node.op)]} " return op.join(values) def visit_IfExp(self, node: ast.IfExp) -> str: """Visit an IfExp node. Conditional expressions nested in the else branch are flattened into one cases-style array, one row per condition plus a final otherwise row.""" latex = r"\left\{ \begin{array}{ll} " current_expr: ast.expr = node while isinstance(current_expr, ast.IfExp): cond_latex = self.visit(current_expr.test) true_latex = self.visit(current_expr.body) latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ " current_expr = current_expr.orelse latex += self.visit(current_expr) return latex + r", & \mathrm{otherwise} \end{array} \right." def _get_sum_prod_range(self, node: ast.comprehension) -> tuple[str, str] | None: """Helper to process range(...) for sum and prod functions. Args: node: comprehension node to be analyzed. Returns: Tuple of following strings: - lower_rhs - upper which are used in _get_sum_prod_info, or None if the analysis failed. """ if not ( isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name) and node.iter.func.id == "range" ): return None try: range_info = analyzers.analyze_range(node.iter) except exceptions.LatexifyError: return None if ( # Only accepts ascending order with step size 1.
range_info.step_int != 1 or ( range_info.start_int is not None and range_info.stop_int is not None and range_info.start_int >= range_info.stop_int ) ): return None if range_info.start_int is None: lower_rhs = self.visit(range_info.start) else: lower_rhs = str(range_info.start_int) if range_info.stop_int is None: upper = self.visit(analyzers.reduce_stop_parameter(range_info.stop)) else: upper = str(range_info.stop_int - 1) return lower_rhs, upper def _get_sum_prod_info( self, node: ast.GeneratorExp ) -> tuple[str, list[tuple[str, str]]]: r"""Process GeneratorExp for sum and prod functions. Args: node: GeneratorExp node to be analyzed. Returns: Tuple of: - elt: LaTeX of the summand/multiplicand. - scripts: list of (lower, upper) script strings, one per generator clause. upper is empty when the clause is not a plain ascending range. They are used to represent sum/prod operators as follows: \sum_{scripts[0][0]}^{scripts[0][1]} \sum_{scripts[1][0]}^{scripts[1][1]} ... {elt} Raises: LatexifyError: Unsupported AST is given. """ elt = self.visit(node.elt) scripts: list[tuple[str, str]] = [] for comp in node.generators: range_args = self._get_sum_prod_range(comp) if range_args is not None and not comp.ifs: target = self.visit(comp.target) lower_rhs, upper = range_args lower = f"{target} = {lower_rhs}" else: lower = self.visit(comp) # Use a usual comprehension form. upper = "" scripts.append((lower, upper)) return elt, scripts # Until 3.8 def visit_Index(self, node: ast.Index) -> str: """Visit an Index node.""" return self.visit(node.value) # type: ignore[attr-defined] def _convert_nested_subscripts(self, node: ast.Subscript) -> tuple[str, list[str]]: """Helper function to convert nested subscription. This function converts x[i][j][...] to "x" and ["i", "j", ...] Args: node: ast.Subscript node to be converted. Returns: Tuple of following strings: - The root value of the subscription. - Sequence of indices, outermost subscript first (source order).
""" if isinstance(node.value, ast.Subscript): value, indices = self._convert_nested_subscripts(node.value) else: value = self.visit(node.value) indices = [] indices.append(self.visit(node.slice)) return value, indices def visit_Subscript(self, node: ast.Subscript) -> str: """Visit a Subscript node.""" value, indices = self._convert_nested_subscripts(node) # TODO(odashi): # "[i][j][...]" may be a possible representation as well as "i, j, ..." indices_str = ", ".join(indices) return f"{value}_{{{indices_str}}}" ================================================ FILE: src/latexify/codegen/expression_codegen_test.py ================================================ """Tests for latexify.codegen.expression_codegen.""" from __future__ import annotations import ast import pytest from latexify import ast_utils, exceptions from latexify.codegen import expression_codegen def test_generic_visit() -> None: """Visiting an unknown AST node type raises LatexifyNotSupportedError.""" class UnknownNode(ast.AST): pass with pytest.raises( exceptions.LatexifyNotSupportedError, match=r"^Unsupported AST: UnknownNode$", ): expression_codegen.ExpressionCodegen().visit(UnknownNode()) @pytest.mark.parametrize( "code,latex", [ ("()", r"\mathopen{}\left( \mathclose{}\right)"), ("(x,)", r"\mathopen{}\left( x \mathclose{}\right)"), ("(x, y)", r"\mathopen{}\left( x, y \mathclose{}\right)"), ("(x, y, z)", r"\mathopen{}\left( x, y, z \mathclose{}\right)"), ], ) def test_visit_tuple(code: str, latex: str) -> None: """Tuple displays render as comma-separated items inside parentheses.""" node = ast_utils.parse_expr(code) assert isinstance(node, ast.Tuple) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ("[]", r"\mathopen{}\left[ \mathclose{}\right]"), ("[x]", r"\mathopen{}\left[ x \mathclose{}\right]"), ("[x, y]", r"\mathopen{}\left[ x, y \mathclose{}\right]"), ("[x, y, z]", r"\mathopen{}\left[ x, y, z \mathclose{}\right]"), ], ) def test_visit_list(code: str, latex: str) -> None: """List displays render as comma-separated items inside brackets.""" node = ast_utils.parse_expr(code) assert isinstance(node, ast.List) assert
expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ # TODO(odashi): Support set(). # ("set()", r"\mathopen{}\left\{ \mathclose{}\right\}"), ("{x}", r"\mathopen{}\left\{ x \mathclose{}\right\}"), ("{x, y}", r"\mathopen{}\left\{ x, y \mathclose{}\right\}"), ("{x, y, z}", r"\mathopen{}\left\{ x, y, z \mathclose{}\right\}"), ], ) def test_visit_set(code: str, latex: str) -> None: """Set displays render as comma-separated items inside escaped braces.""" node = ast_utils.parse_expr(code) assert isinstance(node, ast.Set) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ("[i for i in n]", r"\mathopen{}\left[ i \mid i \in n \mathclose{}\right]"), ( "[i for i in n if i > 0]", r"\mathopen{}\left[ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \mathclose{}\right]", ), ( "[i for i in n if i > 0 if f(i)]", r"\mathopen{}\left[ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \land \mathopen{}\left( f \mathopen{}\left(" r" i \mathclose{}\right) \mathclose{}\right)" r" \mathclose{}\right]", ), ( "[i for k in n for i in k]", r"\mathopen{}\left[ i \mid k \in n, i \in k" r" \mathclose{}\right]", ), ( "[i for k in n for i in k if i > 0]", r"\mathopen{}\left[ i \mid" r" k \in n," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \mathclose{}\right]", ), ( "[i for k in n if f(k) for i in k if i > 0]", r"\mathopen{}\left[ i \mid" r" \mathopen{}\left( k \in n \mathclose{}\right)" r" \land \mathopen{}\left( f \mathopen{}\left(" r" k \mathclose{}\right) \mathclose{}\right)," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \mathclose{}\right]", ), ], ) def test_visit_listcomp(code: str, latex: str) -> None: """List comprehensions render as element, mid bar, then clauses, inside brackets.""" node = ast_utils.parse_expr(code) assert isinstance(node, ast.ListComp) assert
expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ("{i for i in n}", r"\mathopen{}\left\{ i \mid i \in n \mathclose{}\right\}"), ( "{i for i in n if i > 0}", r"\mathopen{}\left\{ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \mathclose{}\right\}", ), ( "{i for i in n if i > 0 if f(i)}", r"\mathopen{}\left\{ i \mid" r" \mathopen{}\left( i \in n \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \land \mathopen{}\left( f \mathopen{}\left(" r" i \mathclose{}\right) \mathclose{}\right)" r" \mathclose{}\right\}", ), ( "{i for k in n for i in k}", r"\mathopen{}\left\{ i \mid k \in n, i \in k" r" \mathclose{}\right\}", ), ( "{i for k in n for i in k if i > 0}", r"\mathopen{}\left\{ i \mid" r" k \in n," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \mathclose{}\right\}", ), ( "{i for k in n if f(k) for i in k if i > 0}", r"\mathopen{}\left\{ i \mid" r" \mathopen{}\left( k \in n \mathclose{}\right)" r" \land \mathopen{}\left( f \mathopen{}\left(" r" k \mathclose{}\right) \mathclose{}\right)," r" \mathopen{}\left( i \in k \mathclose{}\right)" r" \land \mathopen{}\left( i > 0 \mathclose{}\right)" r" \mathclose{}\right\}", ), ], ) def test_visit_setcomp(code: str, latex: str) -> None: """Set comprehensions render as element, mid bar, then clauses, inside escaped braces.""" node = ast_utils.parse_expr(code) assert isinstance(node, ast.SetComp) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ("foo(x)", r"\mathrm{foo} \mathopen{}\left( x \mathclose{}\right)"), ("f(x)", r"f \mathopen{}\left( x \mathclose{}\right)"), ("f(-x)", r"f \mathopen{}\left( -x \mathclose{}\right)"), ("f(x + y)", r"f \mathopen{}\left( x + y \mathclose{}\right)"), ( "f(f(x))", r"f \mathopen{}\left(" r" f \mathopen{}\left( x \mathclose{}\right)" r" \mathclose{}\right)", ), ("f(sqrt(x))", r"f \mathopen{}\left( \sqrt{ x }
\mathclose{}\right)"), ("f(sin(x))", r"f \mathopen{}\left( \sin x \mathclose{}\right)"), ("f(factorial(x))", r"f \mathopen{}\left( x ! \mathclose{}\right)"), ("f(x, y)", r"f \mathopen{}\left( x, y \mathclose{}\right)"), ("sqrt(x)", r"\sqrt{ x }"), ("sqrt(-x)", r"\sqrt{ -x }"), ("sqrt(x + y)", r"\sqrt{ x + y }"), ("sqrt(f(x))", r"\sqrt{ f \mathopen{}\left( x \mathclose{}\right) }"), ("sqrt(sqrt(x))", r"\sqrt{ \sqrt{ x } }"), ("sqrt(sin(x))", r"\sqrt{ \sin x }"), ("sqrt(factorial(x))", r"\sqrt{ x ! }"), ("sin(x)", r"\sin x"), ("sin(-x)", r"\sin \mathopen{}\left( -x \mathclose{}\right)"), ("sin(x + y)", r"\sin \mathopen{}\left( x + y \mathclose{}\right)"), ("sin(f(x))", r"\sin f \mathopen{}\left( x \mathclose{}\right)"), ("sin(sqrt(x))", r"\sin \sqrt{ x }"), ("sin(sin(x))", r"\sin \sin x"), ("sin(factorial(x))", r"\sin \mathopen{}\left( x ! \mathclose{}\right)"), ("factorial(x)", r"x !"), ("factorial(-x)", r"\mathopen{}\left( -x \mathclose{}\right) !"), ("factorial(x + y)", r"\mathopen{}\left( x + y \mathclose{}\right) !"), ( "factorial(f(x))", r"\mathopen{}\left(" r" f \mathopen{}\left( x \mathclose{}\right)" r" \mathclose{}\right) !", ), ("factorial(sqrt(x))", r"\mathopen{}\left( \sqrt{ x } \mathclose{}\right) !"), ("factorial(sin(x))", r"\mathopen{}\left( \sin x \mathclose{}\right) !"), ("factorial(factorial(x))", r"\mathopen{}\left( x !
\mathclose{}\right) !"), ], ) def test_visit_call(code: str, latex: str) -> None: """Calls render per builtin function rules; other calls get a parenthesized argument list.""" node = ast_utils.parse_expr(code) assert isinstance(node, ast.Call) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ ("log(x)**2", r"\mathopen{}\left( \log x \mathclose{}\right)^{2}"), ("log(x**2)", r"\log \mathopen{}\left( x^{2} \mathclose{}\right)"), ( "log(x**2)**3", r"\mathopen{}\left(" r" \log \mathopen{}\left( x^{2} \mathclose{}\right)" r" \mathclose{}\right)^{3}", ), ], ) def test_visit_call_with_pow(code: str, latex: str) -> None: """Unary function calls combined with powers are parenthesized to stay unambiguous.""" node = ast_utils.parse_expr(code) assert isinstance(node, (ast.Call, ast.BinOp)) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "src_suffix,dest_suffix", [ # No arguments ("()", r" \mathopen{}\left( \mathclose{}\right)"), # No comprehension ("(x)", r" x"), ( "([1, 2])", r" \mathopen{}\left[ 1, 2 \mathclose{}\right]", ), ( "({1, 2})", r" \mathopen{}\left\{ 1, 2 \mathclose{}\right\}", ), ("(f(x))", r" f \mathopen{}\left( x \mathclose{}\right)"), # Single comprehension ("(i for i in x)", r"_{i \in x}^{} \mathopen{}\left({i}\mathclose{}\right)"), ( "(i for i in [1, 2])", r"_{i \in \mathopen{}\left[ 1, 2 \mathclose{}\right]}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in {1, 2})", r"_{i \in \mathopen{}\left\{ 1, 2 \mathclose{}\right\}}^{}" r" \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in f(x))", r"_{i \in f \mathopen{}\left( x \mathclose{}\right)}^{}" r" \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n))", r"_{i = 0}^{n - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n + 1))", r"_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n + 2))", r"_{i = 0}^{n + 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( # ast.parse() does not recognize negative integers.
"(i for i in range(n - -1))", r"_{i = 0}^{n - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n - 1))", r"_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n + m))", r"_{i = 0}^{n + m - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n - m))", r"_{i = 0}^{n - m - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(3))", r"_{i = 0}^{2} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(3 + 1))", r"_{i = 0}^{3} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(3 + 2))", r"_{i = 0}^{3 + 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(3 - 1))", r"_{i = 0}^{3 - 2} \mathopen{}\left({i}\mathclose{}\right)", ), ( # ast.parse() does not recognize negative integers. "(i for i in range(3 - -1))", r"_{i = 0}^{3 - -1 - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(3 + m))", r"_{i = 0}^{3 + m - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(3 - m))", r"_{i = 0}^{3 - m - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n, m))", r"_{i = n}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(1, m))", r"_{i = 1}^{m - 1} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n, 3))", r"_{i = n}^{2} \mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in range(n, m, k))", r"_{i \in \mathrm{range} \mathopen{}\left( n, m, k \mathclose{}\right)}^{}" r" \mathopen{}\left({i}\mathclose{}\right)", ), ], ) def test_visit_call_sum_prod(src_suffix: str, dest_suffix: str) -> None: """fsum/sum/prod over a generator render as big operators; plain ascending ranges become explicit bounds.""" for src_fn, dest_fn in [("fsum", r"\sum"), ("sum", r"\sum"), ("prod", r"\prod")]: node = ast_utils.parse_expr(src_fn + src_suffix) assert isinstance(node, ast.Call) assert ( expression_codegen.ExpressionCodegen().visit(node) == dest_fn + dest_suffix ) @pytest.mark.parametrize( "code,latex", [ # 2 clauses ( "sum(i for y in x for i in y)", r"\sum_{y \in x}^{} \sum_{i \in
y}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), ( "sum(i for y in x for z in y for i in z)", r"\sum_{y \in x}^{} \sum_{z \in y}^{} \sum_{i \in z}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), # 3 clauses ( "prod(i for y in x for i in y)", r"\prod_{y \in x}^{} \prod_{i \in y}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), ( "prod(i for y in x for z in y for i in z)", r"\prod_{y \in x}^{} \prod_{z \in y}^{} \prod_{i \in z}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), # reduce stop parameter ( "sum(i for i in range(n+1))", r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", ), ( "prod(i for i in range(n-1))", r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", ), # reduce stop parameter ( "sum(i for i in range(n+1))", r"\sum_{i = 0}^{n} \mathopen{}\left({i}\mathclose{}\right)", ), ( "prod(i for i in range(n-1))", r"\prod_{i = 0}^{n - 2} \mathopen{}\left({i}\mathclose{}\right)", ), ], ) def test_visit_call_sum_prod_multiple_comprehension(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.Call) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "src_suffix,dest_suffix", [ ( "(i for i in x if i < y)", r"_{\mathopen{}\left( i \in x \mathclose{}\right) " r"\land \mathopen{}\left( i < y \mathclose{}\right)}^{} " r"\mathopen{}\left({i}\mathclose{}\right)", ), ( "(i for i in x if i < y if f(i))", r"_{\mathopen{}\left( i \in x \mathclose{}\right)" r" \land \mathopen{}\left( i < y \mathclose{}\right)" r" \land \mathopen{}\left( f \mathopen{}\left(" r" i \mathclose{}\right) \mathclose{}\right)}^{}" r" \mathopen{}\left({i}\mathclose{}\right)", ), ], ) def test_visit_call_sum_prod_with_if(src_suffix: str, dest_suffix: str) -> None: for src_fn, dest_fn in [("sum", r"\sum"), ("prod", r"\prod")]: node = ast_utils.parse_expr(src_fn + src_suffix) assert isinstance(node, ast.Call) assert ( expression_codegen.ExpressionCodegen().visit(node) == dest_fn + 
dest_suffix ) @pytest.mark.parametrize( "code,latex", [ ( "x if x < y else y", r"\left\{ \begin{array}{ll}" r" x, & \mathrm{if} \ x < y \\" r" y, & \mathrm{otherwise}" r" \end{array} \right.", ), ( "x if x < y else (y if y < z else z)", r"\left\{ \begin{array}{ll}" r" x, & \mathrm{if} \ x < y \\" r" y, & \mathrm{if} \ y < z \\" r" z, & \mathrm{otherwise}" r" \end{array} \right.", ), ( "x if x < y else (y if y < z else (z if z < w else w))", r"\left\{ \begin{array}{ll}" r" x, & \mathrm{if} \ x < y \\" r" y, & \mathrm{if} \ y < z \\" r" z, & \mathrm{if} \ z < w \\" r" w, & \mathrm{otherwise}" r" \end{array} \right.", ), ], ) def test_if_then_else(code: str, latex: str) -> None: node = ast_utils.parse_expr(code) assert isinstance(node, ast.IfExp) assert expression_codegen.ExpressionCodegen().visit(node) == latex @pytest.mark.parametrize( "code,latex", [ # x op y ("x**y", r"x^{y}"), ("x * y", r"x y"), ("x @ y", r"x y"), ("x / y", r"\frac{x}{y}"), ("x // y", r"\left\lfloor\frac{x}{y}\right\rfloor"), ("x % y", r"x \mathbin{\%} y"), ("x + y", r"x + y"), ("x - y", r"x - y"), ("x << y", r"x \ll y"), ("x >> y", r"x \gg y"), ("x & y", r"x \mathbin{\&} y"), ("x ^ y", r"x \oplus y"), ("x | y", R"x \mathbin{|} y"), # (x op y) op z ("(x**y)**z", r"\mathopen{}\left( x^{y} \mathclose{}\right)^{z}"), ("(x * y) * z", r"x y z"), ("(x @ y) @ z", r"x y z"), ("(x / y) / z", r"\frac{\frac{x}{y}}{z}"), ( "(x // y) // z", r"\left\lfloor\frac{\left\lfloor\frac{x}{y}\right\rfloor}{z}\right\rfloor", ), ("(x % y) % z", r"x \mathbin{\%} y \mathbin{\%} z"), ("(x + y) + z", r"x + y + z"), ("(x - y) - z", r"x - y - z"), ("(x << y) << z", r"x \ll y \ll z"), ("(x >> y) >> z", r"x \gg y \gg z"), ("(x & y) & z", r"x \mathbin{\&} y \mathbin{\&} z"), ("(x ^ y) ^ z", r"x \oplus y \oplus z"), ("(x | y) | z", r"x \mathbin{|} y \mathbin{|} z"), # x op (y op z) ("x**(y**z)", r"x^{y^{z}}"), ("x * (y * z)", r"x y z"), ("x @ (y @ z)", r"x y z"), ("x / (y / z)", r"\frac{x}{\frac{y}{z}}"), ( "x // (y // z)", 
r"\left\lfloor\frac{x}{\left\lfloor\frac{y}{z}\right\rfloor}\right\rfloor", ), ( "x % (y % z)", r"x \mathbin{\%} \mathopen{}\left( y \mathbin{\%} z \mathclose{}\right)", ), ("x + (y + z)", r"x + y + z"), ("x - (y - z)", r"x - \mathopen{}\left( y - z \mathclose{}\right)"), ("x << (y << z)", r"x \ll \mathopen{}\left( y \ll z \mathclose{}\right)"), ("x >> (y >> z)", r"x \gg \mathopen{}\left( y \gg z \mathclose{}\right)"), ("x & (y & z)", r"x \mathbin{\&} y \mathbin{\&} z"), ("x ^ (y ^ z)", r"x \oplus y \oplus z"), ("x | (y | z)", r"x \mathbin{|} y \mathbin{|} z"), # x OP y op z ("x**y * z", r"x^{y} z"), ("x * y + z", r"x y + z"), ("x @ y + z", r"x y + z"), ("x / y + z", r"\frac{x}{y} + z"), ("x // y + z", r"\left\lfloor\frac{x}{y}\right\rfloor + z"), ("x % y + z", r"x \mathbin{\%} y + z"), ("x + y << z", r"x + y \ll z"), ("x - y << z", r"x - y \ll z"), ("x << y & z", r"x \ll y \mathbin{\&} z"), ("x >> y & z", r"x \gg y \mathbin{\&} z"), ("x & y ^ z", r"x \mathbin{\&} y \oplus z"), ("x ^ y | z", r"x \oplus y \mathbin{|} z"), # x OP (y op z) ("x**(y * z)", r"x^{y z}"), ("x * (y + z)", r"x \cdot \mathopen{}\left( y + z \mathclose{}\right)"), ("x @ (y + z)", r"x \cdot \mathopen{}\left( y + z \mathclose{}\right)"), ("x / (y + z)", r"\frac{x}{y + z}"), ("x // (y + z)", r"\left\lfloor\frac{x}{y + z}\right\rfloor"), ("x % (y + z)", r"x \mathbin{\%} \mathopen{}\left( y + z \mathclose{}\right)"), ("x + (y << z)", r"x + \mathopen{}\left( y \ll z \mathclose{}\right)"), ("x - (y << z)", r"x - \mathopen{}\left( y \ll z \mathclose{}\right)"), ( "x << (y & z)", r"x \ll \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)", ), ( "x >> (y & z)", r"x \gg \mathopen{}\left( y \mathbin{\&} z \mathclose{}\right)", ), ( "x & (y ^ z)", r"x \mathbin{\&} \mathopen{}\left( y \oplus z \mathclose{}\right)", ), ( "x ^ (y | z)", r"x \oplus \mathopen{}\left( y \mathbin{|} z \mathclose{}\right)", ), # x op y OP z ("x * y**z", r"x y^{z}"), ("x + y * z", r"x + y z"), ("x + y @ z", r"x + y z"), ("x + y 
/ z", r"x + \frac{y}{z}"), ("x + y // z", r"x + \left\lfloor\frac{y}{z}\right\rfloor"), ("x + y % z", r"x + y \mathbin{\%} z"), ("x << y + z", r"x \ll y + z"), ("x << y - z", r"x \ll y - z"), ("x & y << z", r"x \mathbin{\&} y \ll z"), ("x & y >> z", r"x \mathbin{\&} y \gg z"), ("x ^ y & z", r"x \oplus y \mathbin{\&} z"), ("x | y ^ z", r"x \mathbin{|} y \oplus z"), # (x op y) OP z ("(x * y)**z", r"\mathopen{}\left( x y \mathclose{}\right)^{z}"), ("(x + y) * z", r"\mathopen{}\left( x + y \mathclose{}\right) z"), ("(x + y) @ z", r"\mathopen{}\left( x + y \mathclose{}\right) z"), ("(x + y) / z", r"\frac{x + y}{z}"), ("(x + y) // z", r"\left\lfloor\frac{x + y}{z}\right\rfloor"), ("(x + y) % z", r"\mathopen{}\left( x + y \mathclose{}\right) \mathbin{\%} z"), ("(x << y) + z", r"\mathopen{}\left( x \ll y \mathclose{}\right) + z"), ("(x << y) - z", r"\mathopen{}\left( x \ll y \mathclose{}\right) - z"), ( "(x & y) << z", r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \ll z", ), ( "(x & y) >> z", r"\mathopen{}\left( x \mathbin{\&} y \mathclose{}\right) \gg z", ), ( "(x ^ y) & z", r"\mathopen{}\left( x \oplus y \mathclose{}\right) \mathbin{\&} z", ), ( "(x | y) ^ z", r"\mathopen{}\left( x \mathbin{|} y \mathclose{}\right) \oplus z", ), # is_wrapped ("(x // y)**z", r"\left\lfloor\frac{x}{y}\right\rfloor^{z}"), # With Call ("x**f(y)", r"x^{f \mathopen{}\left( y \mathclose{}\right)}"), ( "f(x)**y", r"\mathopen{}\left(" r" f \mathopen{}\left( x \mathclose{}\right)" r" \mathclose{}\right)^{y}", ), ("x * f(y)", r"x \cdot f \mathopen{}\left( y \mathclose{}\right)"), ("f(x) * y", r"f \mathopen{}\left( x \mathclose{}\right) \cdot y"), ("x / f(y)", r"\frac{x}{f \mathopen{}\left( y \mathclose{}\right)}"), ("f(x) / y", r"\frac{f \mathopen{}\left( x \mathclose{}\right)}{y}"), ("x + f(y)", r"x + f \mathopen{}\left( y \mathclose{}\right)"), ("f(x) + y", r"f \mathopen{}\left( x \mathclose{}\right) + y"), # With is_wrapped Call ("sqrt(x) ** y", r"\sqrt{ x }^{y}"), # With UnaryOp 
("x**-y", r"x^{-y}"),
        ("(-x)**y", r"\mathopen{}\left( -x \mathclose{}\right)^{y}"),
        ("x * -y", r"x \cdot -y"),
        ("-x * y", r"-x y"),
        ("x / -y", r"\frac{x}{-y}"),
        ("-x / y", r"\frac{-x}{y}"),
        ("x + -y", r"x + -y"),
        ("-x + y", r"-x + y"),
        # With Compare
        ("x**(y == z)", r"x^{y = z}"),
        ("(x == y)**z", r"\mathopen{}\left( x = y \mathclose{}\right)^{z}"),
        ("x * (y == z)", r"x \cdot \mathopen{}\left( y = z \mathclose{}\right)"),
        ("(x == y) * z", r"\mathopen{}\left( x = y \mathclose{}\right) z"),
        ("x / (y == z)", r"\frac{x}{y = z}"),
        ("(x == y) / z", r"\frac{x = y}{z}"),
        ("x + (y == z)", r"x + \mathopen{}\left( y = z \mathclose{}\right)"),
        ("(x == y) + z", r"\mathopen{}\left( x = y \mathclose{}\right) + z"),
        # With BoolOp
        ("x**(y and z)", r"x^{y \land z}"),
        ("(x and y)**z", r"\mathopen{}\left( x \land y \mathclose{}\right)^{z}"),
        ("x * (y and z)", r"x \cdot \mathopen{}\left( y \land z \mathclose{}\right)"),
        ("(x and y) * z", r"\mathopen{}\left( x \land y \mathclose{}\right) z"),
        ("x / (y and z)", r"\frac{x}{y \land z}"),
        ("(x and y) / z", r"\frac{x \land y}{z}"),
        ("x + (y and z)", r"x + \mathopen{}\left( y \land z \mathclose{}\right)"),
        ("(x and y) + z", r"\mathopen{}\left( x \land y \mathclose{}\right) + z"),
    ],
)
def test_visit_binop(code: str, latex: str) -> None:
    """Checks the LaTeX generated for BinOp nodes, including surrounding parentheses."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.BinOp)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        # With literals
        ("+x", r"+x"),
        ("-x", r"-x"),
        ("~x", r"\mathord{\sim} x"),
        ("not x", r"\lnot x"),
        # With Call
        ("+f(x)", r"+f \mathopen{}\left( x \mathclose{}\right)"),
        ("-f(x)", r"-f \mathopen{}\left( x \mathclose{}\right)"),
        ("~f(x)", r"\mathord{\sim} f \mathopen{}\left( x \mathclose{}\right)"),
        ("not f(x)", r"\lnot f \mathopen{}\left( x \mathclose{}\right)"),
        # With BinOp
        ("+(x + y)", r"+\mathopen{}\left( x + y \mathclose{}\right)"),
        ("-(x + y)", r"-\mathopen{}\left( x + y \mathclose{}\right)"),
        ("~(x + y)", r"\mathord{\sim} \mathopen{}\left( x + y \mathclose{}\right)"),
        ("not x + y", r"\lnot \mathopen{}\left( x + y \mathclose{}\right)"),
        # With Compare
        ("+(x == y)", r"+\mathopen{}\left( x = y \mathclose{}\right)"),
        ("-(x == y)", r"-\mathopen{}\left( x = y \mathclose{}\right)"),
        ("~(x == y)", r"\mathord{\sim} \mathopen{}\left( x = y \mathclose{}\right)"),
        ("not x == y", r"\lnot \mathopen{}\left( x = y \mathclose{}\right)"),
        # With BoolOp
        ("+(x and y)", r"+\mathopen{}\left( x \land y \mathclose{}\right)"),
        ("-(x and y)", r"-\mathopen{}\left( x \land y \mathclose{}\right)"),
        (
            "~(x and y)",
            r"\mathord{\sim} \mathopen{}\left( x \land y \mathclose{}\right)",
        ),
        ("not (x and y)", r"\lnot \mathopen{}\left( x \land y \mathclose{}\right)"),
    ],
)
def test_visit_unaryop(code: str, latex: str) -> None:
    """Checks the LaTeX generated for UnaryOp nodes, including operand parentheses.

    Note that `not` binds looser than `+`/`==` in Python, so e.g. "not x + y" parses
    as `not (x + y)` and the operand is expected to be wrapped.
    """
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.UnaryOp)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        # 1 comparator
        ("a == b", "a = b"),
        ("a > b", "a > b"),
        ("a >= b", r"a \ge b"),
        ("a in b", r"a \in b"),
        ("a is b", r"a \equiv b"),
        ("a is not b", r"a \not\equiv b"),
        ("a < b", "a < b"),
        ("a <= b", r"a \le b"),
        ("a != b", r"a \ne b"),
        ("a not in b", r"a \notin b"),
        # 2 comparators
        ("a == b == c", "a = b = c"),
        ("a == b > c", "a = b > c"),
        ("a == b >= c", r"a = b \ge c"),
        ("a == b < c", "a = b < c"),
        ("a == b <= c", r"a = b \le c"),
        ("a > b == c", "a > b = c"),
        ("a > b > c", "a > b > c"),
        ("a > b >= c", r"a > b \ge c"),
        ("a >= b == c", r"a \ge b = c"),
        ("a >= b > c", r"a \ge b > c"),
        ("a >= b >= c", r"a \ge b \ge c"),
        ("a < b == c", "a < b = c"),
        ("a < b < c", "a < b < c"),
        ("a < b <= c", r"a < b \le c"),
        ("a <= b == c", r"a \le b = c"),
        ("a <= b < c", r"a \le b < c"),
        ("a <= b <= c", r"a \le b \le c"),
        # With Call
        ("a == f(b)", r"a = f \mathopen{}\left( b \mathclose{}\right)"),
        ("f(a) == b", r"f \mathopen{}\left( a \mathclose{}\right) = b"),
        # With BinOp
        ("a == b + c", r"a = b + c"),
        ("a + b == c", r"a + b = c"),
        # With UnaryOp
        ("a == -b", r"a = -b"),
        ("-a == b", r"-a = b"),
        ("a == (not b)", r"a = \lnot b"),
        ("(not a) == b", r"\lnot a = b"),
        # With BoolOp
        ("a == (b and c)", r"a = \mathopen{}\left( b \land c \mathclose{}\right)"),
        ("(a and b) == c", r"\mathopen{}\left( a \land b \mathclose{}\right) = c"),
    ],
)
def test_visit_compare(code: str, latex: str) -> None:
    """Checks the LaTeX generated for single and chained Compare nodes."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Compare)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        # With literals
        ("a and b", r"a \land b"),
        ("a and b and c", r"a \land b \land c"),
        ("a or b", r"a \lor b"),
        ("a or b or c", r"a \lor b \lor c"),
        ("a or b and c", r"a \lor b \land c"),
        (
            "(a or b) and c",
            r"\mathopen{}\left( a \lor b \mathclose{}\right) \land c",
        ),
        ("a and b or c", r"a \land b \lor c"),
        (
            "a and (b or c)",
            r"a \land \mathopen{}\left( b \lor c \mathclose{}\right)",
        ),
        # With Call
        ("a and f(b)", r"a \land f \mathopen{}\left( b \mathclose{}\right)"),
        ("f(a) and b", r"f \mathopen{}\left( a \mathclose{}\right) \land b"),
        ("a or f(b)", r"a \lor f \mathopen{}\left( b \mathclose{}\right)"),
        ("f(a) or b", r"f \mathopen{}\left( a \mathclose{}\right) \lor b"),
        # With BinOp
        ("a and b + c", r"a \land b + c"),
        ("a + b and c", r"a + b \land c"),
        ("a or b + c", r"a \lor b + c"),
        ("a + b or c", r"a + b \lor c"),
        # With UnaryOp
        ("a and not b", r"a \land \lnot b"),
        ("not a and b", r"\lnot a \land b"),
        ("a or not b", r"a \lor \lnot b"),
        ("not a or b", r"\lnot a \lor b"),
        # With Compare
        ("a and b == c", r"a \land b = c"),
        ("a == b and c", r"a = b \land c"),
        ("a or b == c", r"a \lor b = c"),
        ("a == b or c", r"a = b \lor c"),
    ],
)
def test_visit_boolop(code: str, latex: str) -> None:
    """Checks the LaTeX generated for BoolOp nodes and the `and`/`or` nesting rules."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.BoolOp)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("0", "0"),
        ("1", "1"),
        ("0.0", "0.0"),
        ("1.5", "1.5"),
        ("0.0j", "0j"),
        ("1.0j", "1j"),
("1.5j", "1.5j"),
        ('"abc"', r'\textrm{"abc"}'),
        ('b"abc"', r"\textrm{b'abc'}"),
        ("None", r"\mathrm{None}"),
        ("False", r"\mathrm{False}"),
        ("True", r"\mathrm{True}"),
        ("...", r"\cdots"),
    ],
)
def test_visit_constant(code: str, latex: str) -> None:
    """Checks the LaTeX generated for Constant nodes of each literal type."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Constant)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("x[0]", "x_{0}"),
        ("x[0][1]", "x_{0, 1}"),
        ("x[0][1][2]", "x_{0, 1, 2}"),
        ("x[foo]", r"x_{\mathrm{foo}}"),
        ("x[floor(x)]", r"x_{\mathopen{}\left\lfloor x \mathclose{}\right\rfloor}"),
    ],
)
def test_visit_subscript(code: str, latex: str) -> None:
    """Checks that chained subscripts are merged into a single LaTeX subscript."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Subscript)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("a - b", r"a \setminus b"),
        ("a & b", r"a \cap b"),
        ("a ^ b", r"a \mathbin{\triangle} b"),
        ("a | b", r"a \cup b"),
    ],
)
def test_visit_binop_use_set_symbols(code: str, latex: str) -> None:
    """Checks the set-operation symbols used for BinOp when use_set_symbols=True."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.BinOp)
    assert (
        expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
    )


@pytest.mark.parametrize(
    "code,latex",
    [
        ("a < b", r"a \subset b"),
        ("a <= b", r"a \subseteq b"),
        ("a > b", r"a \supset b"),
        ("a >= b", r"a \supseteq b"),
    ],
)
def test_visit_compare_use_set_symbols(code: str, latex: str) -> None:
    """Checks the subset/superset symbols used for Compare when use_set_symbols=True."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Compare)
    assert (
        expression_codegen.ExpressionCodegen(use_set_symbols=True).visit(tree) == latex
    )


@pytest.mark.parametrize(
    "code,latex",
    [
        ("array(1)", r"\mathrm{array} \mathopen{}\left( 1 \mathclose{}\right)"),
        (
            "array([])",
            r"\mathrm{array} \mathopen{}\left("
            r" \mathopen{}\left[ \mathclose{}\right]"
            r" \mathclose{}\right)",
        ),
        ("array([1])", r"\begin{bmatrix} 1 \end{bmatrix}"),
        ("array([1, 2, 3])", r"\begin{bmatrix} 1 & 2 & 3 \end{bmatrix}"),
        (
            "array([[]])",
            r"\mathrm{array} \mathopen{}\left("
            r" \mathopen{}\left[ \mathopen{}\left["
            r" \mathclose{}\right] \mathclose{}\right]"
            r" \mathclose{}\right)",
        ),
        ("array([[1]])", r"\begin{bmatrix} 1 \end{bmatrix}"),
        ("array([[1], [2], [3]])", r"\begin{bmatrix} 1 \\ 2 \\ 3 \end{bmatrix}"),
        (
            "array([[1], [2], [3, 4]])",
            r"\mathrm{array} \mathopen{}\left("
            r" \mathopen{}\left["
            r" \mathopen{}\left[ 1 \mathclose{}\right],"
            r" \mathopen{}\left[ 2 \mathclose{}\right],"
            r" \mathopen{}\left[ 3, 4 \mathclose{}\right]"
            r" \mathclose{}\right]"
            r" \mathclose{}\right)",
        ),
        (
            "array([[1, 2], [3, 4], [5, 6]])",
            r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \\ 5 & 6 \end{bmatrix}",
        ),
        # Only checks two cases for ndarray.
        ("ndarray(1)", r"\mathrm{ndarray} \mathopen{}\left( 1 \mathclose{}\right)"),
        ("ndarray([1])", r"\begin{bmatrix} 1 \end{bmatrix}"),
    ],
)
def test_numpy_array(code: str, latex: str) -> None:
    """Checks array/ndarray calls: non-empty rectangular literals become bmatrix.

    Empty or ragged literals, and non-list arguments, are expected to fall back to
    an ordinary function-call rendering.
    """
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("zeros(0)", r"\mathbf{0}^{1 \times 0}"),
        ("zeros(1)", r"\mathbf{0}^{1 \times 1}"),
        ("zeros(2)", r"\mathbf{0}^{1 \times 2}"),
        ("zeros(())", r"0"),
        ("zeros((0,))", r"\mathbf{0}^{1 \times 0}"),
        ("zeros((1,))", r"\mathbf{0}^{1 \times 1}"),
        ("zeros((2,))", r"\mathbf{0}^{1 \times 2}"),
        ("zeros((0, 0))", r"\mathbf{0}^{0 \times 0}"),
        ("zeros((1, 1))", r"\mathbf{0}^{1 \times 1}"),
        ("zeros((2, 3))", r"\mathbf{0}^{2 \times 3}"),
        ("zeros((0, 0, 0))", r"\mathbf{0}^{0 \times 0 \times 0}"),
        ("zeros((1, 1, 1))", r"\mathbf{0}^{1 \times 1 \times 1}"),
        ("zeros((2, 3, 5))", r"\mathbf{0}^{2 \times 3 \times 5}"),
        # Unsupported
        ("zeros()", r"\mathrm{zeros} \mathopen{}\left( \mathclose{}\right)"),
        ("zeros(x)", r"\mathrm{zeros} \mathopen{}\left( x \mathclose{}\right)"),
        ("zeros(0, x)", r"\mathrm{zeros} \mathopen{}\left( 0, x \mathclose{}\right)"),
        (
            "zeros((x,))",
            r"\mathrm{zeros} \mathopen{}\left("
            r" \mathopen{}\left( x \mathclose{}\right)"
            r" \mathclose{}\right)",
        ),
],
)
def test_zeros(code: str, latex: str) -> None:
    """Checks zeros(...) with literal integer shapes; other arguments stay plain calls."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("identity(0)", r"\mathbf{I}_{0}"),
        ("identity(1)", r"\mathbf{I}_{1}"),
        ("identity(2)", r"\mathbf{I}_{2}"),
        # Unsupported
        ("identity()", r"\mathrm{identity} \mathopen{}\left( \mathclose{}\right)"),
        ("identity(x)", r"\mathrm{identity} \mathopen{}\left( x \mathclose{}\right)"),
        (
            "identity(0, x)",
            r"\mathrm{identity} \mathopen{}\left( 0, x \mathclose{}\right)",
        ),
    ],
)
def test_identity(code: str, latex: str) -> None:
    """Checks identity(n) with a literal integer; other arguments stay plain calls."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("transpose(A)", r"\mathbf{A}^\intercal"),
        ("transpose(b)", r"\mathbf{b}^\intercal"),
        # Unsupported
        ("transpose()", r"\mathrm{transpose} \mathopen{}\left( \mathclose{}\right)"),
        ("transpose(2)", r"\mathrm{transpose} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "transpose(a, (1, 0))",
            r"\mathrm{transpose} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_transpose(code: str, latex: str) -> None:
    """Checks transpose(name) rendering; other forms stay plain calls."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("det(A)", r"\det \mathopen{}\left( \mathbf{A} \mathclose{}\right)"),
        ("det(b)", r"\det \mathopen{}\left( \mathbf{b} \mathclose{}\right)"),
        (
            "det([[1, 2], [3, 4]])",
            r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
            r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
        ),
        (
            "det([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\det \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
            r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
        ),
        # Unsupported
        ("det()", r"\mathrm{det} \mathopen{}\left( \mathclose{}\right)"),
        ("det(2)", r"\mathrm{det} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "det(a, (1, 0))",
            r"\mathrm{det} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_determinant(code: str, latex: str) -> None:
    """Checks det(...) on a name or matrix literal; other forms stay plain calls."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        (
            "matrix_rank(A)",
            r"\mathrm{rank} \mathopen{}\left( \mathbf{A} \mathclose{}\right)",
        ),
        (
            "matrix_rank(b)",
            r"\mathrm{rank} \mathopen{}\left( \mathbf{b} \mathclose{}\right)",
        ),
        (
            "matrix_rank([[1, 2], [3, 4]])",
            r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 \\"
            r" 3 & 4 \end{bmatrix} \mathclose{}\right)",
        ),
        (
            "matrix_rank([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\mathrm{rank} \mathopen{}\left( \begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\"
            r" 7 & 8 & 9 \end{bmatrix} \mathclose{}\right)",
        ),
        # Unsupported
        (
            "matrix_rank()",
            r"\mathrm{matrix\_rank} \mathopen{}\left( \mathclose{}\right)",
        ),
        (
            "matrix_rank(2)",
            r"\mathrm{matrix\_rank} \mathopen{}\left( 2 \mathclose{}\right)",
        ),
        (
            "matrix_rank(a, (1, 0))",
            r"\mathrm{matrix\_rank} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_matrix_rank(code: str, latex: str) -> None:
    """Checks matrix_rank(...) on a name or matrix literal; other forms stay plain."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("matrix_power(A, 2)", r"\mathbf{A}^{2}"),
        ("matrix_power(b, 2)", r"\mathbf{b}^{2}"),
        (
            "matrix_power([[1, 2], [3, 4]], 2)",
            r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{2}",
        ),
        (
            "matrix_power([[1, 2, 3], [4, 5, 6], [7, 8, 9]], 42)",
            r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{42}",
        ),
        # Unsupported
        (
            "matrix_power()",
            r"\mathrm{matrix\_power} \mathopen{}\left( \mathclose{}\right)",
        ),
        (
            "matrix_power(2)",
            r"\mathrm{matrix\_power} \mathopen{}\left( 2 \mathclose{}\right)",
        ),
        (
            "matrix_power(a, (1, 0))",
            r"\mathrm{matrix\_power} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_matrix_power(code: str, latex: str) -> None:
    """Checks matrix_power(M, k) with an integer exponent; other forms stay plain."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("inv(A)", r"\mathbf{A}^{-1}"),
        ("inv(b)", r"\mathbf{b}^{-1}"),
        ("inv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{-1}"),
        (
            "inv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{-1}",
        ),
        # Unsupported
        ("inv()", r"\mathrm{inv} \mathopen{}\left( \mathclose{}\right)"),
        ("inv(2)", r"\mathrm{inv} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "inv(a, (1, 0))",
            r"\mathrm{inv} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_inv(code: str, latex: str) -> None:
    """Checks inv(...) on a name or matrix literal; other forms stay plain calls."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


@pytest.mark.parametrize(
    "code,latex",
    [
        ("pinv(A)", r"\mathbf{A}^{+}"),
        ("pinv(b)", r"\mathbf{b}^{+}"),
        ("pinv([[1, 2], [3, 4]])", r"\begin{bmatrix} 1 & 2 \\ 3 & 4 \end{bmatrix}^{+}"),
        (
            "pinv([[1, 2, 3], [4, 5, 6], [7, 8, 9]])",
            r"\begin{bmatrix} 1 & 2 & 3 \\ 4 & 5 & 6 \\ 7 & 8 & 9 \end{bmatrix}^{+}",
        ),
        # Unsupported
        ("pinv()", r"\mathrm{pinv} \mathopen{}\left( \mathclose{}\right)"),
        ("pinv(2)", r"\mathrm{pinv} \mathopen{}\left( 2 \mathclose{}\right)"),
        (
            "pinv(a, (1, 0))",
            r"\mathrm{pinv} \mathopen{}\left( a, "
            r"\mathopen{}\left( 1, 0 \mathclose{}\right) \mathclose{}\right)",
        ),
    ],
)
def test_pinv(code: str, latex: str) -> None:
    """Checks pinv(...) on a name or matrix literal; other forms stay plain calls."""
    tree = ast_utils.parse_expr(code)
    assert isinstance(tree, ast.Call)
    assert expression_codegen.ExpressionCodegen().visit(tree) == latex


# Check list for #89.
# https://github.com/google/latexify_py/issues/89#issuecomment-1344967636 @pytest.mark.parametrize( "left,right,latex", [ ("2", "3", r"2 \cdot 3"), ("2", "y", "2 y"), ("2", "beta", r"2 \beta"), ("2", "bar", r"2 \mathrm{bar}"), ("2", "g(y)", r"2 g \mathopen{}\left( y \mathclose{}\right)"), ("2", "(u + v)", r"2 \mathopen{}\left( u + v \mathclose{}\right)"), ("x", "3", r"x \cdot 3"), ("x", "y", "x y"), ("x", "beta", r"x \beta"), ("x", "bar", r"x \cdot \mathrm{bar}"), ("x", "g(y)", r"x \cdot g \mathopen{}\left( y \mathclose{}\right)"), ("x", "(u + v)", r"x \cdot \mathopen{}\left( u + v \mathclose{}\right)"), ("alpha", "3", r"\alpha \cdot 3"), ("alpha", "y", r"\alpha y"), ("alpha", "beta", r"\alpha \beta"), ("alpha", "bar", r"\alpha \cdot \mathrm{bar}"), ("alpha", "g(y)", r"\alpha \cdot g \mathopen{}\left( y \mathclose{}\right)"), ( "alpha", "(u + v)", r"\alpha \cdot \mathopen{}\left( u + v \mathclose{}\right)", ), ("foo", "3", r"\mathrm{foo} \cdot 3"), ("foo", "y", r"\mathrm{foo} \cdot y"), ("foo", "beta", r"\mathrm{foo} \cdot \beta"), ("foo", "bar", r"\mathrm{foo} \cdot \mathrm{bar}"), ( "foo", "g(y)", r"\mathrm{foo} \cdot g \mathopen{}\left( y \mathclose{}\right)", ), ( "foo", "(u + v)", r"\mathrm{foo} \cdot \mathopen{}\left( u + v \mathclose{}\right)", ), ("f(x)", "3", r"f \mathopen{}\left( x \mathclose{}\right) \cdot 3"), ("f(x)", "y", r"f \mathopen{}\left( x \mathclose{}\right) \cdot y"), ("f(x)", "beta", r"f \mathopen{}\left( x \mathclose{}\right) \cdot \beta"), ( "f(x)", "bar", r"f \mathopen{}\left( x \mathclose{}\right) \cdot \mathrm{bar}", ), ( "f(x)", "g(y)", r"f \mathopen{}\left( x \mathclose{}\right)" r" \cdot g \mathopen{}\left( y \mathclose{}\right)", ), ( "f(x)", "(u + v)", r"f \mathopen{}\left( x \mathclose{}\right)" r" \cdot \mathopen{}\left( u + v \mathclose{}\right)", ), ("(s + t)", "3", r"\mathopen{}\left( s + t \mathclose{}\right) \cdot 3"), ("(s + t)", "y", r"\mathopen{}\left( s + t \mathclose{}\right) y"), ("(s + t)", "beta", r"\mathopen{}\left( 
s + t \mathclose{}\right) \beta"), ( "(s + t)", "bar", r"\mathopen{}\left( s + t \mathclose{}\right) \mathrm{bar}", ), ( "(s + t)", "g(y)", r"\mathopen{}\left( s + t \mathclose{}\right)" r" g \mathopen{}\left( y \mathclose{}\right)", ), ( "(s + t)", "(u + v)", r"\mathopen{}\left( s + t \mathclose{}\right)" r" \mathopen{}\left( u + v \mathclose{}\right)", ), ], ) def test_remove_multiply(left: str, right: str, latex: str) -> None: for op in ["*", "@"]: tree = ast_utils.parse_expr(f"{left} {op} {right}") assert isinstance(tree, ast.BinOp) assert ( expression_codegen.ExpressionCodegen(use_math_symbols=True).visit(tree) == latex ) ================================================ FILE: src/latexify/codegen/expression_rules.py ================================================ """Codegen rules for single expressions.""" from __future__ import annotations import ast import dataclasses # Precedences of operators for BoolOp, BinOp, UnaryOp, and Compare nodes. # Note that this value affects only the appearance of surrounding parentheses for each # expression, and does not affect the AST itself. # See also: # https://docs.python.org/3/reference/expressions.html#operator-precedence _PRECEDENCES: dict[type[ast.AST], int] = { ast.Pow: 120, ast.UAdd: 110, ast.USub: 110, ast.Invert: 110, ast.Mult: 100, ast.MatMult: 100, ast.Div: 100, ast.FloorDiv: 100, ast.Mod: 100, ast.Add: 90, ast.Sub: 90, ast.LShift: 80, ast.RShift: 80, ast.BitAnd: 70, ast.BitXor: 60, ast.BitOr: 50, ast.In: 40, ast.NotIn: 40, ast.Is: 40, ast.IsNot: 40, ast.Lt: 40, ast.LtE: 40, ast.Gt: 40, ast.GtE: 40, ast.NotEq: 40, ast.Eq: 40, # NOTE(odashi): # We assume that the `not` operator has the same precedence with other unary # operators `+`, `-` and `~`, because the LaTeX counterpart $\lnot$ looks to have a # high precedence. # ast.Not: 30, ast.Not: 110, ast.And: 20, ast.Or: 10, } # NOTE(odashi): # Function invocation is treated as a unary operator with a higher precedence. 
# This ensures that the argument with a unary operator is wrapped:
#     exp(x) --> \exp x
#     exp(-x) --> \exp (-x)
#     -exp(x) --> - \exp x
_CALL_PRECEDENCE = _PRECEDENCES[ast.UAdd] + 1

# Precedence reported for nodes that carry no operator (names, constants, attributes,
# subscripts, ...). It is larger than every value in `_PRECEDENCES` and than
# `_CALL_PRECEDENCE`.
_INF_PRECEDENCE = 1_000_000


def get_precedence(node: ast.AST) -> int:
    """Obtains the precedence of the subtree.

    Args:
        node: Subtree to investigate.

    Returns:
        - `_CALL_PRECEDENCE` if `node` is a function call.
        - The precedence of the node's operator if `node` is a BinOp, UnaryOp,
          BoolOp, or Compare node.
        - Otherwise `_INF_PRECEDENCE`, a value larger than any operator precedence.
    """
    if isinstance(node, ast.Call):
        return _CALL_PRECEDENCE

    if isinstance(node, (ast.BinOp, ast.UnaryOp, ast.BoolOp)):
        return _PRECEDENCES[type(node.op)]

    if isinstance(node, ast.Compare):
        # Compare operators have the same precedence. It is enough to check only the
        # first operator.
        return _PRECEDENCES[type(node.ops[0])]

    return _INF_PRECEDENCE


@dataclasses.dataclass(frozen=True)
class BinOperandRule:
    """Syntax rules for operands of BinOp."""

    # Whether to require wrapping operands by parentheses according to the precedence.
    wrap: bool = True

    # Whether to require wrapping operands by parentheses if the operand has the same
    # precedence as this operator.
    # This is used to control the behavior of non-associative operators.
    force: bool = False


@dataclasses.dataclass(frozen=True)
class BinOpRule:
    """Syntax rules for BinOp."""

    # LaTeX pieces placed around the operands, i.e. the result reads
    # latex_left + <left operand> + latex_middle + <right operand> + latex_right
    # (e.g. Div uses "\frac{", "}{", "}" to produce a fraction).
    latex_left: str
    latex_middle: str
    latex_right: str

    # Operand rules.
    operand_left: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)
    operand_right: BinOperandRule = dataclasses.field(default_factory=BinOperandRule)

    # Whether to assume the resulting syntax is wrapped by some bracket operators.
    # If True, the parent operator can avoid wrapping this operator by parentheses.
    is_wrapped: bool = False


BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
    ast.Pow: BinOpRule(
        "",
        "^{",
        "}",
        operand_left=BinOperandRule(force=True),
        operand_right=BinOperandRule(wrap=False),
    ),
    ast.Mult: BinOpRule("", r" \cdot ", ""),
    ast.MatMult: BinOpRule("", r" \cdot ", ""),
    ast.Div: BinOpRule(
        r"\frac{",
        "}{",
        "}",
        operand_left=BinOperandRule(wrap=False),
        operand_right=BinOperandRule(wrap=False),
    ),
    ast.FloorDiv: BinOpRule(
        r"\left\lfloor\frac{",
        "}{",
        r"}\right\rfloor",
        operand_left=BinOperandRule(wrap=False),
        operand_right=BinOperandRule(wrap=False),
        is_wrapped=True,
    ),
    ast.Mod: BinOpRule(
        "", r" \mathbin{\%} ", "", operand_right=BinOperandRule(force=True)
    ),
    ast.Add: BinOpRule("", " + ", ""),
    ast.Sub: BinOpRule("", " - ", "", operand_right=BinOperandRule(force=True)),
    ast.LShift: BinOpRule("", r" \ll ", "", operand_right=BinOperandRule(force=True)),
    ast.RShift: BinOpRule("", r" \gg ", "", operand_right=BinOperandRule(force=True)),
    ast.BitAnd: BinOpRule("", r" \mathbin{\&} ", ""),
    ast.BitXor: BinOpRule("", r" \oplus ", ""),
    ast.BitOr: BinOpRule("", r" \mathbin{|} ", ""),
}

# Typeset for BinOp of sets.
SET_BIN_OP_RULES: dict[type[ast.operator], BinOpRule] = {
    **BIN_OP_RULES,
    ast.Sub: BinOpRule(
        "", r" \setminus ", "", operand_right=BinOperandRule(force=True)
    ),
    ast.BitAnd: BinOpRule("", r" \cap ", ""),
    ast.BitXor: BinOpRule("", r" \mathbin{\triangle} ", ""),
    ast.BitOr: BinOpRule("", r" \cup ", ""),
}

UNARY_OPS: dict[type[ast.unaryop], str] = {
    ast.Invert: r"\mathord{\sim} ",
    ast.UAdd: "+",  # Explicitly adds the $+$ operator.
    ast.USub: "-",
    ast.Not: r"\lnot ",
}

COMPARE_OPS: dict[type[ast.cmpop], str] = {
    ast.Eq: "=",
    ast.Gt: ">",
    ast.GtE: r"\ge",
    ast.In: r"\in",
    ast.Is: r"\equiv",
    ast.IsNot: r"\not\equiv",
    ast.Lt: "<",
    ast.LtE: r"\le",
    ast.NotEq: r"\ne",
    ast.NotIn: r"\notin",
}

# Typeset for Compare of sets.
SET_COMPARE_OPS: dict[type[ast.cmpop], str] = {
    **COMPARE_OPS,
    ast.Gt: r"\supset",
    ast.GtE: r"\supseteq",
    ast.Lt: r"\subset",
    ast.LtE: r"\subseteq",
}

BOOL_OPS: dict[type[ast.boolop], str] = {
    ast.And: r"\land",
    ast.Or: r"\lor",
}


@dataclasses.dataclass(frozen=True)
class FunctionRule:
    """Codegen rules for functions.

    Attributes:
        left: LaTeX expression concatenated to the left-hand side of the arguments.
        right: LaTeX expression concatenated to the right-hand side of the arguments.
        is_unary: Whether the function is treated as a unary operator or not.
        is_wrapped: Whether the resulting syntax is wrapped by brackets or not.
    """

    left: str
    right: str = ""
    is_unary: bool = False
    is_wrapped: bool = False


# Function name => FunctionRule(left, right, is_unary, is_wrapped).
BUILTIN_FUNCS: dict[str, FunctionRule] = {
    "abs": FunctionRule(r"\mathopen{}\left|", r"\mathclose{}\right|", is_wrapped=True),
    "acos": FunctionRule(r"\arccos", is_unary=True),
    "acosh": FunctionRule(r"\mathrm{arcosh}", is_unary=True),
    "arccos": FunctionRule(r"\arccos", is_unary=True),
    "arccot": FunctionRule(r"\mathrm{arccot}", is_unary=True),
    "arccsc": FunctionRule(r"\mathrm{arccsc}", is_unary=True),
    "arcosh": FunctionRule(r"\mathrm{arcosh}", is_unary=True),
    "arcoth": FunctionRule(r"\mathrm{arcoth}", is_unary=True),
    "arcsec": FunctionRule(r"\mathrm{arcsec}", is_unary=True),
    "arcsch": FunctionRule(r"\mathrm{arcsch}", is_unary=True),
    "arcsin": FunctionRule(r"\arcsin", is_unary=True),
    "arctan": FunctionRule(r"\arctan", is_unary=True),
    "arsech": FunctionRule(r"\mathrm{arsech}", is_unary=True),
    "arsinh": FunctionRule(r"\mathrm{arsinh}", is_unary=True),
    "artanh": FunctionRule(r"\mathrm{artanh}", is_unary=True),
    "asin": FunctionRule(r"\arcsin", is_unary=True),
    "asinh": FunctionRule(r"\mathrm{arsinh}", is_unary=True),
    "atan": FunctionRule(r"\arctan", is_unary=True),
    "atanh": FunctionRule(r"\mathrm{artanh}", is_unary=True),
    "ceil": FunctionRule(
        r"\mathopen{}\left\lceil", r"\mathclose{}\right\rceil", is_wrapped=True
    ),
    "cos": FunctionRule(r"\cos", is_unary=True),
    "cosh": FunctionRule(r"\cosh", is_unary=True),
    "cot": FunctionRule(r"\cot", is_unary=True),
    "coth": FunctionRule(r"\coth", is_unary=True),
    "csc": FunctionRule(r"\csc", is_unary=True),
    "csch": FunctionRule(r"\mathrm{csch}", is_unary=True),
    "exp": FunctionRule(r"\exp", is_unary=True),
    "fabs": FunctionRule(r"\mathopen{}\left|", r"\mathclose{}\right|", is_wrapped=True),
    "factorial": FunctionRule("", "!", is_unary=True),
    "floor": FunctionRule(
        r"\mathopen{}\left\lfloor", r"\mathclose{}\right\rfloor", is_wrapped=True
    ),
    "fsum": FunctionRule(r"\sum", is_unary=True),
    "gamma": FunctionRule(r"\Gamma"),
    "log": FunctionRule(r"\log", is_unary=True),
    "log10": FunctionRule(r"\log_{10}", is_unary=True),
    "log2": FunctionRule(r"\log_2", is_unary=True),
    "prod": FunctionRule(r"\prod", is_unary=True),
    "sec": FunctionRule(r"\sec", is_unary=True),
    "sech": FunctionRule(r"\mathrm{sech}", is_unary=True),
    "sin": FunctionRule(r"\sin", is_unary=True),
    "sinh": FunctionRule(r"\sinh", is_unary=True),
    "sqrt": FunctionRule(r"\sqrt{", "}", is_wrapped=True),
    "sum": FunctionRule(r"\sum", is_unary=True),
    "tan": FunctionRule(r"\tan", is_unary=True),
    "tanh": FunctionRule(r"\tanh", is_unary=True),
}

# Identifier names that have a LaTeX symbol command of the same name (e.g. "alpha"
# -> \alpha). NOTE(review): presumably consulted when `use_math_symbols` is enabled;
# the consumer is not visible here, so confirm in identifier_converter.
MATH_SYMBOLS = {
    "aleph",
    "alpha",
    "beta",
    "beth",
    "chi",
    "daleth",
    "delta",
    "digamma",
    "epsilon",
    "eta",
    "gamma",
    "gimel",
    "hbar",
    "infty",
    "iota",
    "kappa",
    "lambda",
    "mu",
    "nabla",
    "nu",
    "omega",
    "phi",
    "pi",
    "psi",
    "rho",
    "sigma",
    "tau",
    "theta",
    "upsilon",
    "varepsilon",
    "varkappa",
    "varphi",
    "varpi",
    "varrho",
    "varsigma",
    "vartheta",
    "xi",
    "zeta",
    "Delta",
    "Gamma",
    "Lambda",
    "Omega",
    "Phi",
    "Pi",
    "Psi",
    "Sigma",
    "Theta",
    "Upsilon",
    "Xi",
}
================================================
FILE: src/latexify/codegen/expression_rules_test.py
================================================
"""Tests for latexify.codegen.expression_rules."""

from __future__ import annotations

import ast

import pytest

from latexify.codegen import expression_rules


@pytest.mark.parametrize(
    "node,precedence",
    [
        (
            ast.Call(func=ast.Name(id="func", ctx=ast.Load()), args=[], keywords=[]),
            expression_rules._CALL_PRECEDENCE,
        ),
        (
            ast.BinOp(
                left=ast.Name(id="left", ctx=ast.Load()),
                op=ast.Add(),
                right=ast.Name(id="right", ctx=ast.Load()),
            ),
            expression_rules._PRECEDENCES[ast.Add],
        ),
        (
            ast.UnaryOp(op=ast.UAdd(), operand=ast.Name(id="operand", ctx=ast.Load())),
            expression_rules._PRECEDENCES[ast.UAdd],
        ),
        (
            ast.BoolOp(op=ast.And(), values=[ast.Name(id="value", ctx=ast.Load())]),
            expression_rules._PRECEDENCES[ast.And],
        ),
        (
            ast.Compare(
                left=ast.Name(id="left", ctx=ast.Load()),
                ops=[ast.Eq()],
                comparators=[ast.Name(id="right", ctx=ast.Load())],
            ),
            expression_rules._PRECEDENCES[ast.Eq],
        ),
        (ast.Name(id="name", ctx=ast.Load()), expression_rules._INF_PRECEDENCE),
        (
            ast.Attribute(
                value=ast.Name(id="value", ctx=ast.Load()), attr="attr", ctx=ast.Load()
            ),
            expression_rules._INF_PRECEDENCE,
        ),
    ],
)
def test_get_precedence(node: ast.AST, precedence: int) -> None:
    """Checks get_precedence for calls, operator nodes, and operator-free nodes."""
    assert expression_rules.get_precedence(node) == precedence
================================================
FILE: src/latexify/codegen/function_codegen.py
================================================
"""Codegen for single functions."""

from __future__ import annotations

import ast
import sys

from latexify import ast_utils, exceptions
from latexify.codegen import codegen_utils, expression_codegen, identifier_converter


class FunctionCodegen(ast.NodeVisitor):
    """Codegen for single functions.

    This codegen works for Module with single FunctionDef node to generate a single
    LaTeX expression of the given function.
    """

    _identifier_converter: identifier_converter.IdentifierConverter
    _use_signature: bool

    def __init__(
        self,
        *,
        use_math_symbols: bool = False,
        use_signature: bool = True,
        use_set_symbols: bool = False,
        escape_underscores: bool = True,
    ) -> None:
        """Initializer.

        Args:
            use_math_symbols: Whether to convert identifiers with a math symbol surface
                (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha").
use_signature: Whether to add the function signature before the expression or not. use_set_symbols: Whether to use set symbols or not. """ self._expression_codegen = expression_codegen.ExpressionCodegen( use_math_symbols=use_math_symbols, use_set_symbols=use_set_symbols, escape_underscores=escape_underscores, ) self._identifier_converter = identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols, escape_underscores=escape_underscores ) self._use_signature = use_signature def generic_visit(self, node: ast.AST) -> str: raise exceptions.LatexifyNotSupportedError( f"Unsupported AST: {type(node).__name__}" ) def visit_Module(self, node: ast.Module) -> str: """Visit a Module node.""" return self.visit(node.body[0]) def visit_FunctionDef(self, node: ast.FunctionDef) -> str: """Visit a FunctionDef node.""" # Function name name_str = self._identifier_converter.convert(node.name)[0] # Arguments arg_strs = [ self._identifier_converter.convert(arg.arg)[0] for arg in node.args.args ] body_strs: list[str] = [] # Assignment statements (if any): x = ... for child in node.body[:-1]: if isinstance(child, ast.Expr) and ast_utils.is_constant(child.value): continue if not isinstance(child, ast.Assign): raise exceptions.LatexifyNotSupportedError( "Codegen supports only Assign nodes in multiline functions, " f"but got: {type(child).__name__}" ) body_strs.append(self.visit(child)) return_stmt = node.body[-1] if sys.version_info.minor >= 10: if not isinstance(return_stmt, (ast.Return, ast.If, ast.Match)): raise exceptions.LatexifySyntaxError( f"Unsupported last statement: {type(return_stmt).__name__}" ) else: if not isinstance(return_stmt, (ast.Return, ast.If)): raise exceptions.LatexifySyntaxError( f"Unsupported last statement: {type(return_stmt).__name__}" ) # Function signature: f(x, ...) signature_str = name_str + "(" + ", ".join(arg_strs) + ")" # Function definition: f(x, ...) \triangleq ... 
return_str = self.visit(return_stmt) if self._use_signature: return_str = signature_str + " = " + return_str if not body_strs: # Only the definition. return return_str # Definition with several assignments. Wrap all statements with array. body_strs.append(return_str) return r"\begin{array}{l} " + r" \\ ".join(body_strs) + r" \end{array}" def visit_Assign(self, node: ast.Assign) -> str: """Visit an Assign node.""" operands: list[str] = [self._expression_codegen.visit(t) for t in node.targets] operands.append(self._expression_codegen.visit(node.value)) return " = ".join(operands) def visit_Return(self, node: ast.Return) -> str: """Visit a Return node.""" return ( self._expression_codegen.visit(node.value) if node.value is not None else codegen_utils.convert_constant(None) ) def visit_If(self, node: ast.If) -> str: """Visit an If node.""" latex = r"\left\{ \begin{array}{ll} " current_stmt: ast.stmt = node while isinstance(current_stmt, ast.If): if len(current_stmt.body) != 1 or len(current_stmt.orelse) != 1: raise exceptions.LatexifySyntaxError( "Multiple statements are not supported in If nodes." ) cond_latex = self._expression_codegen.visit(current_stmt.test) true_latex = self.visit(current_stmt.body[0]) latex += true_latex + r", & \mathrm{if} \ " + cond_latex + r" \\ " current_stmt = current_stmt.orelse[0] latex += self.visit(current_stmt) return latex + r", & \mathrm{otherwise} \end{array} \right." def visit_Match(self, node: ast.Match) -> str: """Visit a Match node""" if not ( len(node.cases) >= 2 and isinstance(node.cases[-1].pattern, ast.MatchAs) and node.cases[-1].pattern.name is None ): raise exceptions.LatexifySyntaxError( "Match statement must contain the wildcard." ) subject_latex = self._expression_codegen.visit(node.subject) case_latexes: list[str] = [] for i, case in enumerate(node.cases): if len(case.body) != 1 or not isinstance(case.body[0], ast.Return): raise exceptions.LatexifyNotSupportedError( "Match cases must contain exactly 1 return statement." 
) if i < len(node.cases) - 1: body_latex = self.visit(case.body[0]) cond_latex = self.visit(case.pattern) case_latexes.append( body_latex + r", & \mathrm{if} \ " + subject_latex + cond_latex ) else: case_latexes.append( self.visit(node.cases[-1].body[0]) + r", & \mathrm{otherwise}" ) return ( r"\left\{ \begin{array}{ll} " + r" \\ ".join(case_latexes) + r" \end{array} \right." ) def visit_MatchValue(self, node: ast.MatchValue) -> str: """Visit a MatchValue node""" latex = self._expression_codegen.visit(node.value) return " = " + latex ================================================ FILE: src/latexify/codegen/function_codegen_match_test.py ================================================ """Tests for FunctionCodegen with match statements.""" from __future__ import annotations import ast import textwrap import pytest from latexify import exceptions, test_utils from latexify.codegen import function_codegen @test_utils.require_at_least(10) def test_functiondef_match() -> None: tree = ast.parse( textwrap.dedent( """ def f(x): match x: case 0: return 1 case _: return 3 * x """ ) ) expected = ( r"f(x) =" r" \left\{ \begin{array}{ll}" r" 1, & \mathrm{if} \ x = 0 \\" r" 3 x, & \mathrm{otherwise}" r" \end{array} \right." ) assert function_codegen.FunctionCodegen().visit(tree) == expected @test_utils.require_at_least(10) def test_matchvalue() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: return 1 case _: return 2 """ ) ).body[0] expected = ( r"\left\{ \begin{array}{ll}" r" 1, & \mathrm{if} \ x = 0 \\" r" 2, & \mathrm{otherwise}" r" \end{array} \right." 
) assert function_codegen.FunctionCodegen().visit(tree) == expected @test_utils.require_at_least(10) def test_multiple_matchvalue() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: return 1 case 1: return 2 case _: return 3 """ ) ).body[0] expected = ( r"\left\{ \begin{array}{ll}" r" 1, & \mathrm{if} \ x = 0 \\" r" 2, & \mathrm{if} \ x = 1 \\" r" 3, & \mathrm{otherwise}" r" \end{array} \right." ) assert function_codegen.FunctionCodegen().visit(tree) == expected @test_utils.require_at_least(10) def test_single_matchvalue_no_wildcards() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: return 1 """ ) ).body[0] with pytest.raises( exceptions.LatexifySyntaxError, match=r"^Match statement must contain the wildcard\.$", ): function_codegen.FunctionCodegen().visit(tree) @test_utils.require_at_least(10) def test_multiple_matchvalue_no_wildcards() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: return 1 case 1: return 2 """ ) ).body[0] with pytest.raises( exceptions.LatexifySyntaxError, match=r"^Match statement must contain the wildcard\.$", ): function_codegen.FunctionCodegen().visit(tree) @test_utils.require_at_least(10) def test_matchas_nonempty() -> None: tree = ast.parse( textwrap.dedent( """ match x: case [x] as y: return 1 case _: return 2 """ ) ).body[0] with pytest.raises( exceptions.LatexifyNotSupportedError, match=r"^Unsupported AST: MatchAs$", ): function_codegen.FunctionCodegen().visit(tree) @test_utils.require_at_least(10) def test_matchvalue_no_return() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: x = 5 case _: return 0 """ ) ).body[0] with pytest.raises( exceptions.LatexifyNotSupportedError, match=r"^Match cases must contain exactly 1 return statement\.$", ): function_codegen.FunctionCodegen().visit(tree) @test_utils.require_at_least(10) def test_matchvalue_mutliple_statements() -> None: tree = ast.parse( textwrap.dedent( """ match x: case 0: x = 5 return 1 case _: return 0 """ ) ).body[0] 
with pytest.raises( exceptions.LatexifyNotSupportedError, match=r"^Match cases must contain exactly 1 return statement\.$", ): function_codegen.FunctionCodegen().visit(tree) ================================================ FILE: src/latexify/codegen/function_codegen_test.py ================================================ """Tests for latexify.codegen.function_codegen.""" from __future__ import annotations import ast import textwrap import pytest from latexify import exceptions from latexify.codegen import function_codegen def test_generic_visit() -> None: class UnknownNode(ast.AST): pass with pytest.raises( exceptions.LatexifyNotSupportedError, match=r"^Unsupported AST: UnknownNode$", ): function_codegen.FunctionCodegen().visit(UnknownNode()) def test_visit_functiondef_use_signature() -> None: tree = ast.parse( textwrap.dedent( """ def f(x): return x """ ) ).body[0] assert isinstance(tree, ast.FunctionDef) latex_without_flag = "x" latex_with_flag = r"f(x) = x" assert function_codegen.FunctionCodegen().visit(tree) == latex_with_flag assert ( function_codegen.FunctionCodegen(use_signature=False).visit(tree) == latex_without_flag ) assert ( function_codegen.FunctionCodegen(use_signature=True).visit(tree) == latex_with_flag ) def test_visit_functiondef_ignore_docstring() -> None: tree = ast.parse( textwrap.dedent( """ def f(x): '''docstring''' return x """ ) ).body[0] assert isinstance(tree, ast.FunctionDef) latex = r"f(x) = x" assert function_codegen.FunctionCodegen().visit(tree) == latex def test_visit_functiondef_ignore_multiple_constants() -> None: tree = ast.parse( textwrap.dedent( """ def f(x): '''docstring''' 3 True return x """ ) ).body[0] assert isinstance(tree, ast.FunctionDef) latex = r"f(x) = x" assert function_codegen.FunctionCodegen().visit(tree) == latex ================================================ FILE: src/latexify/codegen/identifier_converter.py ================================================ """Utility to convert identifiers.""" from __future__ 
import annotations from latexify.codegen import expression_rules class IdentifierConverter: r"""Converts Python identifiers to appropriate LaTeX expression. This converter applies following rules: - `foo` --> `\foo`, if `use_math_symbols == True` and the given identifier matches a supported math symbol name. - `x` --> `x`, if the given identifier is exactly 1 character (except `_`) - `foo_bar` --> `\mathrm{foo\_bar}`, otherwise. """ _use_math_symbols: bool _use_mathrm: bool _escape_underscores: bool def __init__( self, *, use_math_symbols: bool, use_mathrm: bool = True, escape_underscores: bool = True, ) -> None: r"""Initializer. Args: use_math_symbols: Whether to convert identifiers with math symbol names to appropriate LaTeX command. use_mathrm: Whether to wrap the resulting expression by \mathrm, if applicable. escape_underscores: Whether to prefix any underscores in identifiers with '\\', disable to allow subscripts in generated latex. """ self._use_math_symbols = use_math_symbols self._use_mathrm = use_mathrm self._escape_underscores = escape_underscores def convert(self, name: str) -> tuple[str, bool]: """Converts Python identifier to LaTeX expression. Args: name: Identifier name. Returns: Tuple of following values: - latex: Corresponding LaTeX expression. - is_single_character: Whether `latex` can be treated as a single character or not. Raises: LatexifyError: Resulting latex is not valid. This most likely occurs where the symbol starts or ends with an underscore, and escape_underscores=False. """ if not self._escape_underscores and "_" in name: # Check if we are going to generate an invalid Latex string. 
Better to # raise an exception here than have the resulting Latex fail to # compile/display name_splits = name.split("_") if not all(name_splits): raise ValueError( "Neither preceding/trailing underscores nor double underscores is " f"allowed by the `escape_underscores` option, but got: {name}" ) elems = [ IdentifierConverter( use_math_symbols=self._use_math_symbols, use_mathrm=False, escape_underscores=True, ).convert(n)[0] for n in name_splits ] # Wrap sub identifiers in nested braces name = "_{".join(elems) + "}" * (len(elems) - 1) if self._use_math_symbols and name in expression_rules.MATH_SYMBOLS: return "\\" + name, True if len(name) == 1 and name != "_": return name, True escaped = name.replace("_", r"\_") if self._escape_underscores else name wrapped = rf"\mathrm{{{escaped}}}" if self._use_mathrm else escaped return wrapped, False ================================================ FILE: src/latexify/codegen/identifier_converter_test.py ================================================ """Tests for latexify.codegen.identifier_converter.""" from __future__ import annotations import pytest from latexify.codegen import identifier_converter @pytest.mark.parametrize( "name,use_math_symbols,use_mathrm,escape_underscores,expected", [ ("a", False, True, True, ("a", True)), ("_", False, True, True, (r"\mathrm{\_}", False)), ("aa", False, True, True, (r"\mathrm{aa}", False)), ("a1", False, True, True, (r"\mathrm{a1}", False)), ("a_", False, True, True, (r"\mathrm{a\_}", False)), ("_a", False, True, True, (r"\mathrm{\_a}", False)), ("_1", False, True, True, (r"\mathrm{\_1}", False)), ("__", False, True, True, (r"\mathrm{\_\_}", False)), ("a_a", False, True, True, (r"\mathrm{a\_a}", False)), ("a__", False, True, True, (r"\mathrm{a\_\_}", False)), ("a_1", False, True, True, (r"\mathrm{a\_1}", False)), ("alpha", False, True, True, (r"\mathrm{alpha}", False)), ("alpha", True, True, True, (r"\alpha", True)), ("alphabet", True, True, True, (r"\mathrm{alphabet}", False)), 
("foo", False, True, True, (r"\mathrm{foo}", False)), ("foo", True, True, True, (r"\mathrm{foo}", False)), ("foo", True, False, True, (r"foo", False)), ("aa", False, True, False, (r"\mathrm{aa}", False)), ("a_a", False, True, False, (r"\mathrm{a_{a}}", False)), ("a_1", False, True, False, (r"\mathrm{a_{1}}", False)), ("alpha", True, False, False, (r"\alpha", True)), ("alpha_1", True, False, False, (r"\alpha_{1}", False)), ("x_alpha", True, False, False, (r"x_{\alpha}", False)), ("x_alpha_beta", True, False, False, (r"x_{\alpha_{\beta}}", False)), ("alpha_beta", True, False, False, (r"\alpha_{\beta}", False)), ], ) def test_identifier_converter( name: str, use_math_symbols: bool, use_mathrm: bool, escape_underscores: bool, expected: tuple[str, bool], ) -> None: assert ( identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols, use_mathrm=use_mathrm, escape_underscores=escape_underscores, ).convert(name) == expected ) @pytest.mark.parametrize( "name,use_math_symbols,use_mathrm,escape_underscores", [ ("_", False, True, False), ("a_", False, True, False), ("_a", False, True, False), ("_1", False, True, False), ("__", False, True, False), ("a__", False, True, False), ("alpha_", True, False, False), ("_alpha", True, False, False), ("x__alpha", True, False, False), ], ) def test_identifier_converter_failure( name: str, use_math_symbols: bool, use_mathrm: bool, escape_underscores: bool, ) -> None: with pytest.raises(ValueError): identifier_converter.IdentifierConverter( use_math_symbols=use_math_symbols, use_mathrm=use_mathrm, escape_underscores=escape_underscores, ).convert(name) ================================================ FILE: src/latexify/codegen/latex.py ================================================ """Definition of Latex.""" from __future__ import annotations from collections.abc import Iterable from typing import Union LatexLike = Union[str, "Latex"] class Latex: """LaTeX expression string for ease of writing the codegen source.""" _raw: 
str def __init__(self, raw: str) -> None: """Initializer. Args: raw: Direct string of the underlying expression. """ self._raw = raw def __eq__(self, other: object) -> bool: """Checks equality. Args: other: Other object to check equality. Returns: True if other is Latex and the underlying expression is the same as self, False otherwise. """ return isinstance(other, Latex) and other._raw == self._raw def __str__(self) -> str: """Returns the underlying expression. Returns: The underlying expression. """ return self._raw def __add__(self, other: object) -> Latex: """Concatenates two expressions. Args: other: The expression to be concatenated to the right side of self. Returns: A new expression: "{self}{other}" """ if isinstance(other, str): return Latex(self._raw + other) if isinstance(other, Latex): return Latex(self._raw + other._raw) raise ValueError("Unsupported operation.") def __radd__(self, other: object) -> Latex: """Concatenates two expressions. Args: other: The expression to be concatenated to the left side of self. Returns: A new expression: "{other}{self}" """ if isinstance(other, str): return Latex(other + self._raw) if isinstance(other, Latex): return Latex(other._raw + self._raw) raise ValueError("Unsupported operation.") @staticmethod def opt(src: LatexLike) -> Latex: """Wraps the expression by "[" and "]". This wrapping is used when the expression needs to be wrapped as an optional argument of the environment. Args: src: Original expression. Returns: A new expression with surrounding brackets. """ return Latex("[" + str(src) + "]") @staticmethod def arg(src: LatexLike) -> Latex: """Wraps the expression by "{" and "}". This wrapping is used when the expression needs to be wrapped as an argument of other expressions. Args: src: Original expression. Returns: A new expression with surrounding brackets. """ return Latex("{" + str(src) + "}") @staticmethod def paren(src: LatexLike) -> Latex: """Adds surrounding parentheses: "(" and ")". 
Args: src: Original expression. Returns: A new expression with surrounding brackets. """ return Latex(r"\mathopen{}\left( " + str(src) + r" \mathclose{}\right)") @staticmethod def curly(src: LatexLike) -> Latex: """Adds surrounding curly brackets: "\\{" and "\\}". Args: src: Original expression. Returns: A new expression with surrounding brackets. """ return Latex(r"\mathopen{}\left\{ " + str(src) + r" \mathclose{}\right\}") @staticmethod def square(src: LatexLike) -> Latex: """Adds surrounding square brackets: "[" and "]". Args: src: Original expression. Returns: A new expression with surrounding brackets. """ return Latex(r"\mathopen{}\left[ " + str(src) + r" \mathclose{}\right]") @staticmethod def command( name: str, *, options: list[LatexLike] | None = None, args: list[LatexLike] | None = None, ) -> Latex: """Makes a Latex command expression. Args: name: Name of the command. options: List of optional arguments. args: List of arguments. Returns: A new expression. """ elms: list[LatexLike] = [rf"\{name}"] if options is not None: elms += [Latex.opt(x) for x in options] if args is not None: elms += [Latex.arg(x) for x in args] return Latex.join("", elms) @staticmethod def environment( name: str, *, options: list[LatexLike] | None = None, args: list[LatexLike] | None = None, content: LatexLike | None = None, ) -> Latex: """Makes a Latex environment expression. Args: name: Name of the environment. options: List of optional arguments. args: List of arguments. content: Inner content of the environment. Returns: A new expression. 
""" begin_elms: list[LatexLike] = [rf"\begin{{{name}}}"] if options is not None: begin_elms += [Latex.opt(x) for x in options] if args is not None: begin_elms += [Latex.arg(x) for x in args] env_elms: list[LatexLike] = [Latex.join("", begin_elms)] if content is not None: env_elms.append(content) env_elms.append(rf"\end{{{name}}}") return Latex.join(" ", env_elms) @staticmethod def join(separator: LatexLike, elements: Iterable[LatexLike]) -> Latex: """Joins given sequence. Args: separator: Expression of the separator between each element. elements: Iterable of expressions to be joined. Returns: A new Latex: "{e[0]}{s}{e[1]}{s}...{s}{e[-1]}" where s == separator, and e == elements. """ return Latex(str(separator).join(str(x) for x in elements)) ================================================ FILE: src/latexify/codegen/latex_test.py ================================================ """Tests for latexify.codegen.latex.""" from __future__ import annotations # Ignores [22-imports] for convenience. 
from latexify.codegen.latex import Latex def test_eq() -> None: assert Latex("foo") == Latex("foo") assert Latex("foo") != "foo" assert Latex("foo") != Latex("bar") def test_str() -> None: assert str(Latex("foo")) == "foo" def test_add() -> None: assert Latex("foo") + "bar" == Latex("foobar") assert "foo" + Latex("bar") == Latex("foobar") assert Latex("foo") + Latex("bar") == Latex("foobar") def test_opt() -> None: assert Latex.opt("foo") == Latex("[foo]") assert Latex.opt(Latex("foo")) == Latex("[foo]") def test_arg() -> None: assert Latex.arg("foo") == Latex("{foo}") assert Latex.arg(Latex("foo")) == Latex("{foo}") def test_paren() -> None: assert Latex.paren("foo") == Latex(r"\mathopen{}\left( foo \mathclose{}\right)") assert Latex.paren(Latex("foo")) == Latex( r"\mathopen{}\left( foo \mathclose{}\right)" ) def test_curly() -> None: assert Latex.curly("foo") == Latex(r"\mathopen{}\left\{ foo \mathclose{}\right\}") assert Latex.curly(Latex("foo")) == Latex( r"\mathopen{}\left\{ foo \mathclose{}\right\}" ) def test_square() -> None: assert Latex.square("foo") == Latex(r"\mathopen{}\left[ foo \mathclose{}\right]") assert Latex.square(Latex("foo")) == Latex( r"\mathopen{}\left[ foo \mathclose{}\right]" ) def test_command() -> None: assert Latex.command("a") == Latex(r"\a") assert Latex.command("a", options=[]) == Latex(r"\a") assert Latex.command("a", options=["b"]) == Latex(r"\a[b]") assert Latex.command("a", options=[Latex("b")]) == Latex(r"\a[b]") assert Latex.command("a", options=["b", "c"]) == Latex(r"\a[b][c]") assert Latex.command("a", args=[]) == Latex(r"\a") assert Latex.command("a", args=["b"]) == Latex(r"\a{b}") assert Latex.command("a", args=[Latex("b")]) == Latex(r"\a{b}") assert Latex.command("a", args=["b", "c"]) == Latex(r"\a{b}{c}") assert Latex.command("a", options=["b"], args=["c"]) == Latex(r"\a[b]{c}") def test_environment() -> None: assert Latex.environment("a") == Latex(r"\begin{a} \end{a}") assert Latex.environment("a", options=[]) == 
Latex(r"\begin{a} \end{a}") assert Latex.environment("a", options=["b"]) == Latex(r"\begin{a}[b] \end{a}") assert Latex.environment("a", options=[Latex("b")]) == Latex( r"\begin{a}[b] \end{a}" ) assert Latex.environment("a", options=["b", "c"]) == Latex( r"\begin{a}[b][c] \end{a}" ) assert Latex.environment("a", args=[]) == Latex(r"\begin{a} \end{a}") assert Latex.environment("a", args=["b"]) == Latex(r"\begin{a}{b} \end{a}") assert Latex.environment("a", args=[Latex("b")]) == Latex(r"\begin{a}{b} \end{a}") assert Latex.environment("a", args=["b", "c"]) == Latex(r"\begin{a}{b}{c} \end{a}") assert Latex.environment("a", content="b") == Latex(r"\begin{a} b \end{a}") assert Latex.environment("a", content=Latex("b")) == Latex(r"\begin{a} b \end{a}") assert Latex.environment("a", options=["b"], args=["c"]) == Latex( r"\begin{a}[b]{c} \end{a}" ) assert Latex.environment("a", options=["b"], content="c") == Latex( r"\begin{a}[b] c \end{a}" ) assert Latex.environment("a", args=["b"], content="c") == Latex( r"\begin{a}{b} c \end{a}" ) assert Latex.environment("a", options=["b"], args=["c"], content="d") == Latex( r"\begin{a}[b]{c} d \end{a}" ) def test_join() -> None: assert Latex.join(":", []) == Latex("") assert Latex.join(":", ["foo"]) == Latex("foo") assert Latex.join(":", [Latex("foo")]) == Latex("foo") assert Latex.join(":", [Latex("foo"), "bar"]) == Latex("foo:bar") assert Latex.join(":", ["foo", Latex("bar")]) == Latex("foo:bar") assert Latex.join(":", [Latex("foo"), Latex("bar")]) == Latex("foo:bar") assert Latex.join(":", ()) == Latex("") assert Latex.join(":", ("foo",)) == Latex("foo") assert Latex.join(":", (Latex("foo"),)) == Latex("foo") assert Latex.join(":", (Latex("foo"), "bar")) == Latex("foo:bar") assert Latex.join(":", ("foo", Latex("bar"))) == Latex("foo:bar") assert Latex.join(":", (Latex("foo"), Latex("bar"))) == Latex("foo:bar") assert Latex.join(":", (str(x) for x in range(3))) == Latex("0:1:2") assert Latex.join(":", (Latex(str(x)) for x in 
range(3))) == Latex("0:1:2") ================================================ FILE: src/latexify/config.py ================================================ """Definition of the Config class.""" from __future__ import annotations import dataclasses from typing import Any @dataclasses.dataclass(frozen=True) class Config: """Configurations to control the behavior of latexify. Attributes: expand_functions: If set, the names of the functions to expand. identifiers: If set, the mapping to replace identifier names in the function. Keys are the original names of the identifiers, and corresponding values are the replacements. Both keys and values have to represent valid Python identifiers: ^[A-Za-z_][A-Za-z0-9_]*$ prefixes: Prefixes of identifiers to trim. E.g., if "foo.bar" in prefixes, all identifiers with the form "foo.bar.suffix" will be replaced to "suffix" reduce_assignments: If True, assignment statements are used to synthesize the final expression. use_math_symbols: Whether to convert identifiers with a math symbol surface (e.g., "alpha") to the LaTeX symbol (e.g., "\\alpha"). use_set_symbols: Whether to use set symbols or not. use_signature: Whether to add the function signature before the expression or not. """ expand_functions: set[str] | None identifiers: dict[str, str] | None prefixes: set[str] | None reduce_assignments: bool use_math_symbols: bool use_set_symbols: bool use_signature: bool escape_underscores: bool def merge(self, *, config: Config | None = None, **kwargs) -> Config: """Merge configuration based on old configuration and field values. Args: config: If None, the merged one will merge defaults and field values, instead of merging old configuration and field values. **kwargs: Members to modify. This value precedes both self and config. 
Returns: A new Config object """ def merge_field(name: str) -> Any: # Precedence: kwargs -> config -> self arg = kwargs.get(name) if arg is None: if config is not None: arg = getattr(config, name) else: arg = getattr(self, name) return arg return Config(**{f.name: merge_field(f.name) for f in dataclasses.fields(self)}) @staticmethod def defaults() -> Config: """Generates a Config with default values. Returns: A new Config with default values """ return Config( expand_functions=None, identifiers=None, prefixes=None, reduce_assignments=False, use_math_symbols=False, use_set_symbols=False, use_signature=True, escape_underscores=True, ) ================================================ FILE: src/latexify/exceptions.py ================================================ """Exceptions used in Latexify.""" class LatexifyError(Exception): """Base class of all Latexify exceptions. Subclasses of this exception does not mean incorrect use of the library by the user at the interface level. These exceptions inform users that Latexify went into something wrong during processing the given functions. These exceptions are usually captured by the frontend functions (e.g., `with_latex`) to prevent destroying the entire program. Errors caused by wrong inputs should raise built-in exceptions. """ ... class LatexifyNotSupportedError(LatexifyError): """Some subtree in the AST is not supported by the current implementation. This error is raised when the library discovered incompatible syntaxes due to lack of the implementation. Possibly this error would be resolved in the future. """ ... class LatexifySyntaxError(LatexifyError): """Some subtree in the AST is not supported. This error is raised when the library discovered syntaxes that are not possible to be processed anymore. This error is essential, and wouldn't be resolved in the future. """ ... 
================================================ FILE: src/latexify/frontend.py ================================================ """Frontend interfaces of latexify.""" from __future__ import annotations from collections.abc import Callable from typing import Any, overload from latexify import ipython_wrappers @overload def algorithmic( fn: Callable[..., Any], **kwargs: Any ) -> ipython_wrappers.LatexifiedAlgorithm: ... @overload def algorithmic( **kwargs: Any, ) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm]: ... def algorithmic( fn: Callable[..., Any] | None = None, **kwargs: Any ) -> ( ipython_wrappers.LatexifiedAlgorithm | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedAlgorithm] ): """Attach LaTeX pretty-printing to the given function. This function works with or without specifying the target function as the positional argument. The following two syntaxes works similarly. - latexify.algorithmic(alg, **kwargs) - latexify.algorithmic(**kwargs)(alg) Args: fn: Callable to be wrapped. **kwargs: Arguments to control behavior. See also get_latex(). Returns: - If `fn` is passed, returns the wrapped function. - Otherwise, returns the wrapper function with given settings. """ if fn is not None: return ipython_wrappers.LatexifiedAlgorithm(fn, **kwargs) def wrapper(f): return ipython_wrappers.LatexifiedAlgorithm(f, **kwargs) return wrapper @overload def function( fn: Callable[..., Any], **kwargs: Any ) -> ipython_wrappers.LatexifiedFunction: ... @overload def function( **kwargs: Any, ) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]: ... def function( fn: Callable[..., Any] | None = None, **kwargs: Any ) -> ( ipython_wrappers.LatexifiedFunction | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction] ): """Attach LaTeX pretty-printing to the given function. This function works with or without specifying the target function as the positional argument. The following two syntaxes works similarly. 
- latexify.function(fn, **kwargs) - latexify.function(**kwargs)(fn) Args: fn: Callable to be wrapped. **kwargs: Arguments to control behavior. See also get_latex(). Returns: - If `fn` is passed, returns the wrapped function. - Otherwise, returns the wrapper function with given settings. """ if fn is not None: return ipython_wrappers.LatexifiedFunction(fn, **kwargs) def wrapper(f): return ipython_wrappers.LatexifiedFunction(f, **kwargs) return wrapper @overload def expression( fn: Callable[..., Any], **kwargs: Any ) -> ipython_wrappers.LatexifiedFunction: ... @overload def expression( **kwargs: Any, ) -> Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction]: ... def expression( fn: Callable[..., Any] | None = None, **kwargs: Any ) -> ( ipython_wrappers.LatexifiedFunction | Callable[[Callable[..., Any]], ipython_wrappers.LatexifiedFunction] ): """Attach LaTeX pretty-printing to the given function. This function is a shortcut for `latexify.function` with the default parameter `use_signature=False`. """ kwargs["use_signature"] = kwargs.get("use_signature", False) if fn is not None: return ipython_wrappers.LatexifiedFunction(fn, **kwargs) def wrapper(f): return ipython_wrappers.LatexifiedFunction(f, **kwargs) return wrapper ================================================ FILE: src/latexify/frontend_test.py ================================================ """Tests for latexify.frontend.""" from __future__ import annotations from latexify import frontend def test_function() -> None: def f(x): return x latex_without_flag = "x" latex_with_flag = r"f(x) = x" # Checks the syntax: # @function # def fn(...): # ... latexified = frontend.function(f) assert str(latexified) == latex_with_flag assert latexified._repr_latex_() == rf"$$ \displaystyle {latex_with_flag} $$" # Checks the syntax: # @function(**kwargs) # def fn(...): # ... 
latexified = frontend.function(use_signature=False)(f) assert str(latexified) == latex_without_flag assert latexified._repr_latex_() == rf"$$ \displaystyle {latex_without_flag} $$" # Checks the syntax: # def fn(...): # ... # latexified = function(fn, **kwargs) latexified = frontend.function(f, use_signature=False) assert str(latexified) == latex_without_flag assert latexified._repr_latex_() == rf"$$ \displaystyle {latex_without_flag} $$" def test_expression() -> None: def f(x): return x latex_without_flag = "x" latex_with_flag = r"f(x) = x" # Checks the syntax: # @expression # def fn(...): # ... latexified = frontend.expression(f) assert str(latexified) == latex_without_flag assert latexified._repr_latex_() == rf"$$ \displaystyle {latex_without_flag} $$" # Checks the syntax: # @expression(**kwargs) # def fn(...): # ... latexified = frontend.expression(use_signature=True)(f) assert str(latexified) == latex_with_flag assert latexified._repr_latex_() == rf"$$ \displaystyle {latex_with_flag} $$" # Checks the syntax: # def fn(...): # ... # latexified = expression(fn, **kwargs) latexified = frontend.expression(f, use_signature=True) assert str(latexified) == latex_with_flag assert latexified._repr_latex_() == rf"$$ \displaystyle {latex_with_flag} $$" ================================================ FILE: src/latexify/generate_latex.py ================================================ """Generate LaTeX code.""" from __future__ import annotations import enum from collections.abc import Callable from typing import Any from latexify import codegen from latexify import config as cfg from latexify import parser, transformers class Style(enum.Enum): """The style of the generated LaTeX.""" ALGORITHMIC = "algorithmic" FUNCTION = "function" IPYTHON_ALGORITHMIC = "ipython-algorithmic" def get_latex( fn: Callable[..., Any], *, style: Style = Style.FUNCTION, config: cfg.Config | None = None, **kwargs, ) -> str: """Obtains LaTeX description from the function's source. 
Args: fn: Reference to a function to analyze. style: Style of the LaTeX description, the default is FUNCTION. config: Use defined Config object, if it is None, it will be automatic assigned with default value. **kwargs: Dict of Config field values that could be defined individually by users. Returns: Generated LaTeX description. Raises: latexify.exceptions.LatexifyError: Something went wrong during conversion. """ merged_config = cfg.Config.defaults().merge(config=config, **kwargs) # Obtains the source AST. tree = parser.parse_function(fn) # Mandatory AST Transformation. tree = transformers.AugAssignReplacer().visit(tree) # Conditional AST transformation. if merged_config.prefixes is not None: tree = transformers.PrefixTrimmer(merged_config.prefixes).visit(tree) if merged_config.identifiers is not None: tree = transformers.IdentifierReplacer(merged_config.identifiers).visit(tree) if merged_config.reduce_assignments: tree = transformers.DocstringRemover().visit(tree) tree = transformers.AssignmentReducer().visit(tree) if merged_config.expand_functions is not None: tree = transformers.FunctionExpander(merged_config.expand_functions).visit(tree) # Generates LaTeX. 

    if style == Style.ALGORITHMIC:
        return codegen.AlgorithmicCodegen(
            use_math_symbols=merged_config.use_math_symbols,
            use_set_symbols=merged_config.use_set_symbols,
            escape_underscores=merged_config.escape_underscores,
        ).visit(tree)
    elif style == Style.FUNCTION:
        # `use_signature` is only passed to the FUNCTION-style generator.
        return codegen.FunctionCodegen(
            use_math_symbols=merged_config.use_math_symbols,
            use_signature=merged_config.use_signature,
            use_set_symbols=merged_config.use_set_symbols,
            escape_underscores=merged_config.escape_underscores,
        ).visit(tree)
    elif style == Style.IPYTHON_ALGORITHMIC:
        return codegen.IPythonAlgorithmicCodegen(
            use_math_symbols=merged_config.use_math_symbols,
            use_set_symbols=merged_config.use_set_symbols,
            escape_underscores=merged_config.escape_underscores,
        ).visit(tree)

    raise ValueError(f"Unrecognized style: {style}")


================================================
FILE: src/latexify/generate_latex_test.py
================================================
"""Tests for latexify.generate_latex."""

from __future__ import annotations

from latexify import generate_latex


def test_get_latex_identifiers() -> None:
    """`identifiers` renames both the function and its arguments."""

    def myfn(myvar):
        return 3 * myvar

    identifiers = {"myfn": "f", "myvar": "x"}

    latex_without_flag = r"\mathrm{myfn}(\mathrm{myvar}) = 3 \mathrm{myvar}"
    latex_with_flag = r"f(x) = 3 x"

    assert generate_latex.get_latex(myfn) == latex_without_flag
    assert generate_latex.get_latex(myfn, identifiers=identifiers) == latex_with_flag


def test_get_latex_prefixes() -> None:
    """`prefixes` trims matching leading attribute chains."""
    abc = object()

    def f(x):
        return abc.d + x.y.z.e

    latex_without_flag = r"f(x) = \mathrm{abc}.d + x.y.z.e"
    latex_with_flag1 = r"f(x) = d + x.y.z.e"
    latex_with_flag2 = r"f(x) = \mathrm{abc}.d + y.z.e"
    latex_with_flag3 = r"f(x) = \mathrm{abc}.d + z.e"
    latex_with_flag4 = r"f(x) = d + e"

    assert generate_latex.get_latex(f) == latex_without_flag
    assert generate_latex.get_latex(f, prefixes=set()) == latex_without_flag
    assert generate_latex.get_latex(f, prefixes={"abc"}) == latex_with_flag1
    assert generate_latex.get_latex(f, prefixes={"x"}) ==
latex_with_flag2
    assert generate_latex.get_latex(f, prefixes={"x.y"}) == latex_with_flag3
    assert generate_latex.get_latex(f, prefixes={"abc", "x.y.z"}) == latex_with_flag4
    # The longest matching prefix wins: "x.y.z" is trimmed, not only "x".
    assert (
        generate_latex.get_latex(f, prefixes={"abc", "x", "x.y.z"})
        == latex_with_flag4
    )


def test_get_latex_reduce_assignments() -> None:
    """`reduce_assignments` inlines intermediate assignments into the return."""

    def f(x):
        y = 3 * x
        return y

    latex_without_flag = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
    latex_with_flag = r"f(x) = 3 x"

    assert generate_latex.get_latex(f) == latex_without_flag
    assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
    assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_docstring() -> None:
    """A docstring in the fixture must not break assignment reduction."""

    def f(x):
        """DocstringRemover is required."""
        y = 3 * x
        return y

    latex_without_flag = r"\begin{array}{l} y = 3 x \\ f(x) = y \end{array}"
    latex_with_flag = r"f(x) = 3 x"

    assert generate_latex.get_latex(f) == latex_without_flag
    assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
    assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_reduce_assignments_with_aug_assign() -> None:
    """`y *= x` is rewritten to `y = y * x` before reduction."""

    def f(x):
        y = 3
        y *= x
        return y

    latex_without_flag = r"\begin{array}{l} y = 3 \\ y = y x \\ f(x) = y \end{array}"
    latex_with_flag = r"f(x) = 3 x"

    assert generate_latex.get_latex(f) == latex_without_flag
    assert generate_latex.get_latex(f, reduce_assignments=False) == latex_without_flag
    assert generate_latex.get_latex(f, reduce_assignments=True) == latex_with_flag


def test_get_latex_use_math_symbols() -> None:
    """Names such as `alpha` become symbols only with use_math_symbols=True."""

    def f(alpha):
        return alpha

    latex_without_flag = r"f(\mathrm{alpha}) = \mathrm{alpha}"
    latex_with_flag = r"f(\alpha) = \alpha"

    assert generate_latex.get_latex(f) == latex_without_flag
    assert generate_latex.get_latex(f, use_math_symbols=False) == latex_without_flag
    assert generate_latex.get_latex(f, use_math_symbols=True) == latex_with_flag


def test_get_latex_use_signature() -> None:
    """The signature is shown by default and hidden with use_signature=False."""

    def f(x):
        return x
latex_without_flag = "x" latex_with_flag = r"f(x) = x" assert generate_latex.get_latex(f) == latex_with_flag assert generate_latex.get_latex(f, use_signature=False) == latex_without_flag assert generate_latex.get_latex(f, use_signature=True) == latex_with_flag def test_get_latex_use_set_symbols() -> None: def f(x, y): return x & y latex_without_flag = r"f(x, y) = x \mathbin{\&} y" latex_with_flag = r"f(x, y) = x \cap y" assert generate_latex.get_latex(f) == latex_without_flag assert generate_latex.get_latex(f, use_set_symbols=False) == latex_without_flag assert generate_latex.get_latex(f, use_set_symbols=True) == latex_with_flag ================================================ FILE: src/latexify/ipython_wrappers.py ================================================ """Wrapper objects for IPython to display output.""" from __future__ import annotations import abc from typing import Any, Callable, cast from latexify import exceptions, generate_latex class LatexifiedRepr(metaclass=abc.ABCMeta): """Object with LaTeX representation.""" _fn: Callable[..., Any] def __init__(self, fn: Callable[..., Any], **kwargs) -> None: self._fn = fn @property def __doc__(self) -> str | None: return self._fn.__doc__ @__doc__.setter def __doc__(self, val: str | None) -> None: self._fn.__doc__ = val @property def __name__(self) -> str: return self._fn.__name__ @__name__.setter def __name__(self, val: str) -> None: self._fn.__name__ = val # After Python 3.7 # @final def __call__(self, *args) -> Any: return self._fn(*args) @abc.abstractmethod def __str__(self) -> str: ... @abc.abstractmethod def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None: """IPython hook to display HTML visualization.""" ... @abc.abstractmethod def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None: """IPython hook to display LaTeX visualization.""" ... 
def _try_get_latex(
    fn: Callable[..., Any], style: generate_latex.Style, **kwargs
) -> tuple[str | None, str | None]:
    """Generates LaTeX for `fn`, capturing conversion errors.

    Args:
        fn: Function to convert.
        style: Output style passed to `generate_latex.get_latex`.
        **kwargs: Config overrides forwarded to `generate_latex.get_latex`.

    Returns:
        `(latex, None)` on success, or `(None, message)` if a LatexifyError was
        raised, where `message` has the form "<ErrorTypeName>: <details>".
    """
    try:
        return generate_latex.get_latex(fn, style=style, **kwargs), None
    except exceptions.LatexifyError as e:
        return None, f"{type(e).__name__}: {str(e)}"


class LatexifiedAlgorithm(LatexifiedRepr):
    """Algorithm with latex representation.

    Two renderings are generated eagerly at construction: the plain algorithmic
    LaTeX (used by `str()`) and the IPython variant (used by the display hooks).
    """

    _latex: str | None
    _error: str | None
    _ipython_latex: str | None
    _ipython_error: str | None

    def __init__(self, fn: Callable[..., Any], **kwargs) -> None:
        super().__init__(fn, **kwargs)
        self._latex, self._error = _try_get_latex(
            fn, generate_latex.Style.ALGORITHMIC, **kwargs
        )
        self._ipython_latex, self._ipython_error = _try_get_latex(
            fn, generate_latex.Style.IPYTHON_ALGORITHMIC, **kwargs
        )

    def __str__(self) -> str:
        return self._latex if self._latex is not None else cast(str, self._error)

    def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display HTML visualization.

        Only used to surface the conversion error; returns None on success.
        """
        # NOTE(review): concatenating empty strings is a no-op; presumably an
        # HTML wrapper around the error text was intended here -- confirm.
        return (
            '' + self._ipython_error + ""
            if self._ipython_error is not None
            else None
        )

    def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display LaTeX visualization."""
        return (
            f"$ {self._ipython_latex} $"
            if self._ipython_latex is not None
            else self._ipython_error
        )


class LatexifiedFunction(LatexifiedRepr):
    """Function with latex representation."""

    _latex: str | None
    _error: str | None

    def __init__(self, fn: Callable[..., Any], **kwargs) -> None:
        super().__init__(fn, **kwargs)
        self._latex, self._error = _try_get_latex(
            fn, generate_latex.Style.FUNCTION, **kwargs
        )

    def __str__(self) -> str:
        return self._latex if self._latex is not None else cast(str, self._error)

    def _repr_html_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display HTML visualization.

        Only used to surface the conversion error; returns None on success.
        """
        return (
            '' + self._error + ""
            if self._error is not None
            else None
        )

    def _repr_latex_(self) -> str | tuple[str, dict[str, Any]] | None:
        """IPython hook to display LaTeX visualization.

        Returns the LaTeX in display-math form on success, otherwise the
        captured error message.
        """
        return (
            rf"$$ \displaystyle {self._latex} $$"
            if self._latex is not None
            else self._error
        )


================================================
FILE: src/latexify/parser.py
================================================
"""Parsing utilities."""

from __future__ import annotations

import ast
import inspect
import textwrap
from collections.abc import Callable
from typing import Any

import dill  # type: ignore[import]

from latexify import exceptions


def parse_function(fn: Callable[..., Any]) -> ast.Module:
    """Parses given function.

    Args:
        fn: Target function.

    Returns:
        AST tree representing `fn`; its first statement is the FunctionDef.

    Raises:
        latexify.exceptions.LatexifySyntaxError: The retrieved source does not
            start with a function definition (e.g., `fn` is a lambda).
    """
    try:
        source = inspect.getsource(fn)
    except Exception:
        # Maybe running on console, where inspect cannot locate the source;
        # fall back to dill's source retrieval.
        source = dill.source.getsource(fn)

    # Remove extra indentation so that ast.parse runs correctly.
    source = textwrap.dedent(source)

    tree = ast.parse(source)
    if not tree.body or not isinstance(tree.body[0], ast.FunctionDef):
        raise exceptions.LatexifySyntaxError("Not a function.")

    return tree


================================================
FILE: src/latexify/parser_test.py
================================================
"""Tests for latexify.parser."""

from __future__ import annotations

import ast

import pytest

from latexify import ast_utils, exceptions, parser, test_utils


def test_parse_function_with_posonlyargs() -> None:
    """A plain function parses into a Module holding its FunctionDef."""

    def f(x):
        return x

    expected = ast.Module(
        body=[
            ast_utils.create_function_def(
                name="f",
                args=ast.arguments(
                    posonlyargs=[],
                    args=[ast.arg(arg="x")],
                    vararg=None,
                    kwonlyargs=[],
                    kw_defaults=[],
                    kwarg=None,
                    defaults=[],
                ),
                body=[ast.Return(value=ast.Name(id="x", ctx=ast.Load()))],
                decorator_list=[],
                returns=None,
                type_comment=None,
                type_params=[],
                lineno=1,
                col_offset=0,
                end_lineno=2,
                end_col_offset=0,
            )
        ],
        type_ignores=[],
    )

    obtained = parser.parse_function(f)
    test_utils.assert_ast_equal(obtained, expected)


def test_parse_function_with_lambda() -> None:
    """Lambdas are rejected: their source line is not a function definition."""
    with
pytest.raises(exceptions.LatexifySyntaxError, match=r"^Not a function\.$"):
        parser.parse_function(lambda: ())

    with pytest.raises(exceptions.LatexifySyntaxError, match=r"^Not a function\.$"):
        x = lambda: ()  # noqa: E731
        parser.parse_function(x)


================================================
FILE: src/latexify/test_utils.py
================================================
"""Test utilities."""

from __future__ import annotations

import ast
import functools
import sys
from collections.abc import Callable
from typing import cast


def require_at_least(
    minor: int,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Require the minimum minor version of Python 3 to run the test.

    On older interpreters the wrapped test returns immediately without running
    its body.

    Args:
        minor: Minimum minor version (inclusive) that the test case supports.

    Returns:
        A decorator function to wrap the test case function.
    """

    def decorator(fn: Callable[..., None]) -> Callable[..., None]:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if sys.version_info.minor < minor:
                # NOTE(review): returning normally makes pytest count the test
                # as passed rather than skipped.
                return
            fn(*args, **kwargs)

        return wrapper

    return decorator


def require_at_most(
    minor: int,
) -> Callable[[Callable[..., None]], Callable[..., None]]:
    """Require the maximum minor version of Python 3 to run the test.

    On newer interpreters the wrapped test returns immediately without running
    its body.

    Args:
        minor: Maximum minor version (inclusive) that the test case supports.

    Returns:
        A decorator function to wrap the test case function.
    """

    def decorator(fn: Callable[..., None]) -> Callable[..., None]:
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            if sys.version_info.minor > minor:
                # NOTE(review): as above, this is reported as a pass.
                return
            fn(*args, **kwargs)

        return wrapper

    return decorator


def ast_equal(observed: ast.AST, expected: ast.AST) -> bool:
    """Checks the equality between two ASTs.

    This function checks if `observed` contains at least the same subtree as
    `expected`. If `observed` has some extra branches that `expected` does not
    cover, they are ignored. Source-location fields and any `kind` field are
    never compared.

    Args:
        observed: An AST to check.
        expected: The expected AST.

    Returns:
        True if observed and expected represent the same AST. A mismatch is
        never reported as False; it raises instead.

    Raises:
        AssertionError: The two ASTs differ.
        AttributeError: `observed` lacks a field that `expected` has.
""" ignore_keys = {"lineno", "col_offset", "end_lineno", "end_col_offset", "kind"} if sys.version_info.minor <= 12: ignore_keys.add("type_params") try: assert type(observed) is type(expected) for k, ve in vars(expected).items(): if k in ignore_keys: continue vo = getattr(observed, k) # May cause AttributeError. if isinstance(ve, ast.AST): assert ast_equal(cast(ast.AST, vo), ve) elif isinstance(ve, list): vo = cast(list, vo) assert len(vo) == len(ve) assert all( ast_equal(cast(ast.AST, co), cast(ast.AST, ce)) for co, ce in zip(vo, ve) ) else: assert type(vo) is type(ve) assert vo == ve except (AssertionError, AttributeError): raise # raise to debug easier. return True def assert_ast_equal(observed: ast.AST, expected: ast.AST) -> None: """Asserts the equality between two ASTs. Args: observed: An AST to compare. expected: Another AST. Raises: AssertionError: observed and expected represent different ASTs. """ assert ast_equal( observed, expected ), f"""\ AST does not match. observed={ast.dump(observed, indent=4)} expected={ast.dump(expected, indent=4)} """ ================================================ FILE: src/latexify/transformers/__init__.py ================================================ """Package latexify.transformers.""" from latexify.transformers.assignment_reducer import AssignmentReducer from latexify.transformers.aug_assign_replacer import AugAssignReplacer from latexify.transformers.docstring_remover import DocstringRemover from latexify.transformers.function_expander import FunctionExpander from latexify.transformers.identifier_replacer import IdentifierReplacer from latexify.transformers.prefix_trimmer import PrefixTrimmer __all__ = [ "AssignmentReducer", "AugAssignReplacer", "DocstringRemover", "FunctionExpander", "IdentifierReplacer", "PrefixTrimmer", ] ================================================ FILE: src/latexify/transformers/assignment_reducer.py ================================================ """NodeTransformer to reduce assigned 
expressions."""

from __future__ import annotations

import ast
from typing import Any

from latexify import ast_utils, exceptions


class AssignmentReducer(ast.NodeTransformer):
    """NodeTransformer to reduce assigned expressions.

    This class rewrites a function whose body is a sequence of assignments
    followed by a final statement into a function with only that final
    statement, in which every assigned name is replaced by its expression.

    Example:
        def f(x):
            y = 2 + x
            z = 3 * y
            return 4 + z

        AssignmentReducer modifies the function above to below:

        def f(x):
            return 4 + 3 * (2 + x)
    """

    # Maps each assigned name to its (already reduced) value expression while a
    # FunctionDef is being visited; None outside of any function.
    _assignments: dict[str, ast.expr] | None = None

    # TODO(odashi):
    # Currently, this function does not care much about some expressions, e.g.,
    # comprehensions or lambdas, which introduces inner scopes.
    # It may cause some mistakes in the resulting AST.

    def visit_FunctionDef(self, node: ast.FunctionDef) -> Any:
        """Visit a FunctionDef node.

        Raises:
            LatexifyNotSupportedError: A non-final statement is not an Assign,
                or an assignment target is not a plain name.
            LatexifySyntaxError: The final statement is neither Return nor If.
        """
        # Push stack
        parent_assignments = self._assignments
        self._assignments = {}

        for child in node.body[:-1]:
            if not isinstance(child, ast.Assign):
                raise exceptions.LatexifyNotSupportedError(
                    "AssignmentReducer supports only Assign nodes, "
                    f"but got: {type(child).__name__}"
                )

            # Visiting the value first substitutes the earlier assignments into
            # it, so each stored expression is already fully expanded.
            value = self.visit(child.value)

            for target in child.targets:
                if not isinstance(target, ast.Name):
                    raise exceptions.LatexifyNotSupportedError(
                        "AssignmentReducer does not recognize list/tuple "
                        "decomposition."
)
                # Later assignments to the same name overwrite earlier ones.
                self._assignments[target.id] = value

        return_original = node.body[-1]
        if not isinstance(return_original, (ast.Return, ast.If)):
            raise exceptions.LatexifySyntaxError(
                f"Unsupported last statement: {type(return_original).__name__}"
            )

        return_transformed = self.visit(return_original)

        # Pop stack
        self._assignments = parent_assignments

        # `type_params` only exists on Python 3.12+.
        type_params = getattr(node, "type_params", [])

        return ast_utils.create_function_def(
            name=node.name,
            args=node.args,
            body=[return_transformed],
            decorator_list=node.decorator_list,
            returns=node.returns,
            type_params=type_params,
        )

    def visit_Name(self, node: ast.Name) -> Any:
        """Visit a Name node.

        Inside a function, a name with a recorded assignment is replaced by its
        expression; any other name is returned unchanged.
        """
        if self._assignments is not None:
            # NOTE: the same expression object is shared by every use site.
            return self._assignments.get(node.id, node)
        return node


================================================
FILE: src/latexify/transformers/assignment_reducer_test.py
================================================
"""Tests for latexify.transformers.assignment_reducer."""

from __future__ import annotations

import ast

from latexify import ast_utils, parser, test_utils
from latexify.transformers.assignment_reducer import AssignmentReducer


def _make_ast(body: list[ast.stmt]) -> ast.Module:
    """Helper function to generate an AST for f(x).

    Args:
        body: The function body.

    Returns:
        Generated AST.
""" return ast.Module( body=[ ast_utils.create_function_def( name="f", args=ast.arguments( args=[ast.arg(arg="x")], kwonlyargs=[], kw_defaults=[], defaults=[], posonlyargs=[], ), body=body, decorator_list=[], type_params=[], ) ], type_ignores=[], ) def test_unchanged() -> None: def f(x): return x expected = _make_ast( [ ast.Return(value=ast.Name(id="x", ctx=ast.Load())), ] ) transformed = AssignmentReducer().visit(parser.parse_function(f)) test_utils.assert_ast_equal(transformed, expected) def test_constant() -> None: def f(x): y = 0 return y expected = _make_ast( [ ast.Return(value=ast_utils.make_constant(0)), ] ) transformed = AssignmentReducer().visit(parser.parse_function(f)) test_utils.assert_ast_equal(transformed, expected) def test_nested() -> None: def f(x): y = 2 * x return y expected = _make_ast( [ ast.Return( value=ast.BinOp( left=ast_utils.make_constant(2), op=ast.Mult(), right=ast.Name(id="x", ctx=ast.Load()), ) ) ] ) transformed = AssignmentReducer().visit(parser.parse_function(f)) test_utils.assert_ast_equal(transformed, expected) def test_nested2() -> None: def f(x): y = 2 * x z = 3 + y return z expected = _make_ast( [ ast.Return( value=ast.BinOp( left=ast_utils.make_constant(3), op=ast.Add(), right=ast.BinOp( left=ast_utils.make_constant(2), op=ast.Mult(), right=ast.Name(id="x", ctx=ast.Load()), ), ) ) ] ) transformed = AssignmentReducer().visit(parser.parse_function(f)) test_utils.assert_ast_equal(transformed, expected) def test_overwrite() -> None: def f(x): y = 2 * x y = 3 + x return y expected = _make_ast( [ ast.Return( value=ast.BinOp( left=ast_utils.make_constant(3), op=ast.Add(), right=ast.Name(id="x", ctx=ast.Load()), ) ) ] ) transformed = AssignmentReducer().visit(parser.parse_function(f)) test_utils.assert_ast_equal(transformed, expected) ================================================ FILE: src/latexify/transformers/aug_assign_replacer.py ================================================ """Transformer to replace AugAssign to Assign.""" 
from __future__ import annotations

import ast


class AugAssignReplacer(ast.NodeTransformer):
    """NodeTransformer to replace AugAssign with the corresponding Assign.

    AugAssign(target, op, value) => Assign([target], BinOp(target, op, value))
    """

    def visit_AugAssign(self, node: ast.AugAssign) -> ast.Assign:
        """Rewrites `target op= value` as `target = target op value`."""
        # The target is reused as the left operand, which is read (Load) rather
        # than written (Store): build a shallow copy of it with ctx=Load.
        left_args = {**vars(node.target), "ctx": ast.Load()}
        left = type(node.target)(**left_args)
        return ast.Assign(
            targets=[node.target], value=ast.BinOp(left, node.op, node.value)
        )


================================================
FILE: src/latexify/transformers/aug_assign_replacer_test.py
================================================
"""Tests for latexify.transformers.aug_assign_replacer."""

import ast

from latexify import test_utils
from latexify.transformers.aug_assign_replacer import AugAssignReplacer


def test_replace() -> None:
    """`x += y` becomes `x = x + y`."""
    tree = ast.AugAssign(
        target=ast.Name(id="x", ctx=ast.Store()),
        op=ast.Add(),
        value=ast.Name(id="y", ctx=ast.Load()),
    )
    expected = ast.Assign(
        targets=[ast.Name(id="x", ctx=ast.Store())],
        value=ast.BinOp(
            left=ast.Name(id="x", ctx=ast.Load()),
            op=ast.Add(),
            right=ast.Name(id="y", ctx=ast.Load()),
        ),
    )
    transformed = AugAssignReplacer().visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


================================================
FILE: src/latexify/transformers/docstring_remover.py
================================================
"""Transformer to remove all docstrings."""

from __future__ import annotations

import ast
from typing import Union

from latexify import ast_utils


class DocstringRemover(ast.NodeTransformer):
    """NodeTransformer to remove all docstrings.

    Docstrings here are detected as Expr nodes with a single string constant.
    Every such statement is removed wherever it appears, not only the leading
    docstring of a body.
""" def visit_Expr(self, node: ast.Expr) -> Union[ast.Expr, None]: if ast_utils.is_str(node.value): return None return node ================================================ FILE: src/latexify/transformers/docstring_remover_test.py ================================================ """Tests for latexify.transformers.docstring_remover.""" import ast from latexify import ast_utils, parser, test_utils from latexify.transformers.docstring_remover import DocstringRemover def test_remove_docstrings() -> None: def f(): """Test docstring.""" x = 42 f() # This Expr should not be removed. """This string constant should also be removed.""" return x tree = parser.parse_function(f).body[0] assert isinstance(tree, ast.FunctionDef) expected = ast_utils.create_function_def( name="f", body=[ ast.Assign( targets=[ast.Name(id="x", ctx=ast.Store())], value=ast_utils.make_constant(42), ), ast.Expr( value=ast.Call( func=ast.Name(id="f", ctx=ast.Load()), args=[], keywords=[] ) ), ast.Return(value=ast.Name(id="x", ctx=ast.Load())), ], args=ast.arguments( posonlyargs=[], args=[], vararg=None, kwonlyargs=[], kw_defaults=[], kwarg=None, defaults=[], ), decorator_list=[], type_params=[], ) transformed = DocstringRemover().visit(tree) test_utils.assert_ast_equal(transformed, expected) ================================================ FILE: src/latexify/transformers/function_expander.py ================================================ from __future__ import annotations import ast import functools from collections.abc import Callable from latexify import ast_utils, exceptions # TODO(ZibingZhang): handle mutually recursive function expansions class FunctionExpander(ast.NodeTransformer): """NodeTransformer to expand functions. This class replaces function calls with an expanded form. 

    Only calls whose function name was selected by the user and has an entry in
    `_FUNCTION_EXPANDERS` are expanded; all other calls are rebuilt unchanged.

    Example:
        def f(x, y):
            return hypot(x, y)

        FunctionExpander({"hypot"}) will modify the AST of the function above to
        below:

        def f(x, y):
            return sqrt(x**2 + y**2)
    """

    def __init__(self, functions: set[str]) -> None:
        """Initializer.

        Args:
            functions: Names of the functions to expand, e.g. {"exp", "hypot"}.
        """
        self._functions = functions

    def visit_Call(self, node: ast.Call) -> ast.AST:
        """Visit a Call node."""
        func_name = ast_utils.extract_function_name_or_none(node)
        if (
            func_name is not None
            and func_name in self._functions
            and func_name in _FUNCTION_EXPANDERS
        ):
            return _FUNCTION_EXPANDERS[func_name](self, node)

        # Not expanded: rebuild the call from visited children so that nested
        # calls can still be expanded.
        kwargs = {
            "func": self.visit(node.func),
            "args": [self.visit(x) for x in node.args],
        }
        if hasattr(node, "keywords"):
            kwargs["keywords"] = [
                ast.keyword(arg=x.arg, value=self.visit(x.value))
                for x in node.keywords
            ]

        return ast.Call(**kwargs)


def _atan2_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands atan2(y, x) into atan(y / x)."""
    # NOTE(review): atan(y / x) equals atan2(y, x) only when x > 0.
    _check_num_args(node, 2)
    return ast.Call(
        func=ast.Name(id="atan", ctx=ast.Load()),
        args=[
            ast.BinOp(
                left=function_expander.visit(node.args[0]),
                op=ast.Div(),
                right=function_expander.visit(node.args[1]),
            )
        ],
        keywords=[],
    )


def _exp_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands exp(x) into e**x."""
    _check_num_args(node, 1)
    return ast.BinOp(
        left=ast.Name(id="e", ctx=ast.Load()),
        op=ast.Pow(),
        right=function_expander.visit(node.args[0]),
    )


def _exp2_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands exp2(x) into 2**x."""
    _check_num_args(node, 1)
    return ast.BinOp(
        left=ast_utils.make_constant(2),
        op=ast.Pow(),
        right=function_expander.visit(node.args[0]),
    )


def _expm1_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands expm1(x) into exp(x) - 1."""
    _check_num_args(node, 1)
    return ast.BinOp(
        # The synthesized exp() call is visited, so it is further expanded only
        # if "exp" itself was selected.
        left=function_expander.visit(
            ast.Call(
                func=ast.Name(id="exp", ctx=ast.Load()),
                args=[node.args[0]],
                keywords=[],
            )
        ),
        op=ast.Sub(),
        right=ast_utils.make_constant(1),
    )


def _hypot_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands hypot(x1, ..., xn) into sqrt(x1**2 + ... + xn**2)."""
    if not node.args:
        # hypot() with no arguments is 0.
        return ast_utils.make_constant(0)

    args = [
        ast.BinOp(
left=function_expander.visit(arg),
            op=ast.Pow(),
            right=ast_utils.make_constant(2),
        )
        for arg in node.args
    ]

    # Fold the squares left-to-right into a chain of additions.
    args_reduced = functools.reduce(
        lambda a, b: ast.BinOp(left=a, op=ast.Add(), right=b), args
    )
    return ast.Call(
        func=ast.Name(id="sqrt", ctx=ast.Load()),
        args=[args_reduced],
        keywords=[],
    )


def _log1p_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands log1p(x) into log(1 + x)."""
    _check_num_args(node, 1)
    return ast.Call(
        func=ast.Name(id="log", ctx=ast.Load()),
        args=[
            ast.BinOp(
                left=ast_utils.make_constant(1),
                op=ast.Add(),
                right=function_expander.visit(node.args[0]),
            )
        ],
        keywords=[],
    )


def _pow_expander(function_expander: FunctionExpander, node: ast.Call) -> ast.AST:
    """Expands pow(x, y) into x**y."""
    _check_num_args(node, 2)
    return ast.BinOp(
        left=function_expander.visit(node.args[0]),
        op=ast.Pow(),
        right=function_expander.visit(node.args[1]),
    )


def _check_num_args(node: ast.Call, nargs: int) -> None:
    """Ensures that `node` has exactly `nargs` positional arguments.

    Raises:
        LatexifySyntaxError: The number of positional arguments differs.
    """
    if len(node.args) != nargs:
        fn_name = ast_utils.extract_function_name_or_none(node)
        raise exceptions.LatexifySyntaxError(
            f"Incorrect number of arguments for {fn_name}."
f" expected: {nargs}, but got {len(node.args)}" ) _FUNCTION_EXPANDERS: dict[str, Callable[[FunctionExpander, ast.Call], ast.AST]] = { "atan2": _atan2_expander, "exp": _exp_expander, "exp2": _exp2_expander, "expm1": _expm1_expander, "hypot": _hypot_expander, "log1p": _log1p_expander, "pow": _pow_expander, } ================================================ FILE: src/latexify/transformers/function_expander_test.py ================================================ """Tests for latexify.transformers.function_expander.""" from __future__ import annotations import ast from latexify import ast_utils, test_utils from latexify.transformers.function_expander import FunctionExpander def test_preserve_keywords() -> None: tree = ast.Call( func=ast_utils.make_name("f"), args=[ast_utils.make_name("x")], keywords=[ast.keyword(arg="y", value=ast_utils.make_constant(0))], ) expected = ast.Call( func=ast_utils.make_name("f"), args=[ast_utils.make_name("x")], keywords=[ast.keyword(arg="y", value=ast_utils.make_constant(0))], ) transformed = FunctionExpander(set()).visit(tree) test_utils.assert_ast_equal(transformed, expected) def test_exp() -> None: tree = ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], keywords=[], ) expected = ast.BinOp( left=ast_utils.make_name("e"), op=ast.Pow(), right=ast_utils.make_name("x"), ) transformed = FunctionExpander({"exp"}).visit(tree) test_utils.assert_ast_equal(transformed, expected) def test_exp_unchanged() -> None: tree = ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], keywords=[], ) expected = ast.Call( func=ast_utils.make_name("exp"), args=[ast_utils.make_name("x")], keywords=[], ) transformed = FunctionExpander(set()).visit(tree) test_utils.assert_ast_equal(transformed, expected) def test_exp_with_attribute() -> None: tree = ast.Call( func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"), args=[ast_utils.make_name("x")], keywords=[], ) expected = ast.BinOp( 
left=ast_utils.make_name("e"),
        op=ast.Pow(),
        right=ast_utils.make_name("x"),
    )
    transformed2 = FunctionExpander({"exp"}).visit(tree)
    test_utils.assert_ast_equal(transformed2, expected)


def test_exp_unchanged_with_attribute() -> None:
    """An unselected attribute call is left intact."""
    tree = ast.Call(
        func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expected = ast.Call(
        func=ast_utils.make_attribute(ast_utils.make_name("math"), "exp"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    transformed = FunctionExpander(set()).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_exp_nested1() -> None:
    """Nested selected calls are expanded recursively."""
    tree = ast.Call(
        func=ast_utils.make_name("exp"),
        args=[
            ast.Call(
                func=ast_utils.make_name("exp"),
                args=[ast_utils.make_name("x")],
                keywords=[],
            )
        ],
        keywords=[],
    )
    expected = ast.BinOp(
        left=ast_utils.make_name("e"),
        op=ast.Pow(),
        right=ast.BinOp(
            left=ast_utils.make_name("e"),
            op=ast.Pow(),
            right=ast_utils.make_name("x"),
        ),
    )
    transformed = FunctionExpander({"exp"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_exp_nested2() -> None:
    """A selected call nested in an unselected one is still expanded."""
    tree = ast.Call(
        func=ast_utils.make_name("f"),
        args=[
            ast.Call(
                func=ast_utils.make_name("exp"),
                args=[ast_utils.make_name("x")],
                keywords=[],
            )
        ],
        keywords=[],
    )
    expected = ast.Call(
        func=ast_utils.make_name("f"),
        args=[
            ast.BinOp(
                left=ast_utils.make_name("e"),
                op=ast.Pow(),
                right=ast_utils.make_name("x"),
            )
        ],
        keywords=[],
    )
    transformed = FunctionExpander({"exp"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_atan2() -> None:
    """atan2(y, x) expands to atan(y / x)."""
    tree = ast.Call(
        func=ast_utils.make_name("atan2"),
        args=[ast_utils.make_name("y"), ast_utils.make_name("x")],
        keywords=[],
    )
    expected = ast.Call(
        func=ast_utils.make_name("atan"),
        args=[
            ast.BinOp(
                left=ast_utils.make_name("y"),
                op=ast.Div(),
                right=ast_utils.make_name("x"),
            )
        ],
        keywords=[],
    )
    transformed = FunctionExpander({"atan2"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_exp2() -> None:
    """exp2(x) expands to 2**x."""
    tree =
ast.Call(
        func=ast_utils.make_name("exp2"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expected = ast.BinOp(
        left=ast_utils.make_constant(2),
        op=ast.Pow(),
        right=ast_utils.make_name("x"),
    )
    transformed = FunctionExpander({"exp2"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_expm1() -> None:
    """expm1(x) expands to exp(x) - 1; exp stays unless selected too."""
    tree = ast.Call(
        func=ast_utils.make_name("expm1"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expected = ast.BinOp(
        left=ast.Call(
            func=ast_utils.make_name("exp"),
            args=[ast_utils.make_name("x")],
            keywords=[],
        ),
        op=ast.Sub(),
        right=ast_utils.make_constant(1),
    )
    transformed = FunctionExpander({"expm1"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_hypot() -> None:
    """hypot(x, y) expands to sqrt(x**2 + y**2)."""
    tree = ast.Call(
        func=ast_utils.make_name("hypot"),
        args=[ast_utils.make_name("x"), ast_utils.make_name("y")],
        keywords=[],
    )
    expected = ast.Call(
        func=ast_utils.make_name("sqrt"),
        args=[
            ast.BinOp(
                left=ast.BinOp(
                    left=ast_utils.make_name("x"),
                    op=ast.Pow(),
                    right=ast_utils.make_constant(2),
                ),
                op=ast.Add(),
                right=ast.BinOp(
                    left=ast_utils.make_name("y"),
                    op=ast.Pow(),
                    right=ast_utils.make_constant(2),
                ),
            )
        ],
        keywords=[],
    )
    transformed = FunctionExpander({"hypot"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_hypot_no_args() -> None:
    """hypot() with no arguments expands to the constant 0."""
    tree = ast.Call(func=ast_utils.make_name("hypot"), args=[], keywords=[])
    expected = ast_utils.make_constant(0)
    transformed = FunctionExpander({"hypot"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_log1p() -> None:
    """log1p(x) expands to log(1 + x)."""
    tree = ast.Call(
        func=ast_utils.make_name("log1p"),
        args=[ast_utils.make_name("x")],
        keywords=[],
    )
    expected = ast.Call(
        func=ast_utils.make_name("log"),
        args=[
            ast.BinOp(
                left=ast_utils.make_constant(1),
                op=ast.Add(),
                right=ast_utils.make_name("x"),
            )
        ],
        keywords=[],
    )
    transformed = FunctionExpander({"log1p"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


def test_pow() -> None:
    """pow(x, y) expands to x**y."""
    tree = ast.Call(
        func=ast_utils.make_name("pow"),
args=[ast_utils.make_name("x"), ast_utils.make_name("y")],
        keywords=[],
    )
    expected = ast.BinOp(
        left=ast_utils.make_name("x"),
        op=ast.Pow(),
        right=ast_utils.make_name("y"),
    )
    transformed = FunctionExpander({"pow"}).visit(tree)
    test_utils.assert_ast_equal(transformed, expected)


================================================
FILE: src/latexify/transformers/identifier_replacer.py
================================================
"""Transformer to replace user symbols."""

from __future__ import annotations

import ast
import keyword
from typing import cast

from latexify import ast_utils


class IdentifierReplacer(ast.NodeTransformer):
    """NodeTransformer to replace identifier names.

    This class defines a rule to replace identifiers in AST with specified names.

    Example:
        def foo(bar):
            return baz

        IdentifierReplacer({"foo": "x", "bar": "y", "baz": "z"}) will modify the
        AST of the function above to below:

        def x(y):
            return z
    """

    def __init__(self, mapping: dict[str, str]) -> None:
        """Initializer.

        Args:
            mapping: User defined mapping of names. Keys are the original names
                of the identifiers, and corresponding values are the
                replacements.
Both keys and values have to represent valid Python identifiers: ^[A-Za-z_][A-Za-z0-9_]*$ """ self._mapping = mapping for k, v in self._mapping.items(): if not str.isidentifier(k) or keyword.iskeyword(k): raise ValueError(f"'{k}' is not an identifier name.") if not str.isidentifier(v) or keyword.iskeyword(v): raise ValueError(f"'{v}' is not an identifier name.") def _replace_args(self, args: list[ast.arg]) -> list[ast.arg]: """Helper function to replace arg names.""" return [ast.arg(arg=self._mapping.get(a.arg, a.arg)) for a in args] def visit_FunctionDef(self, node: ast.FunctionDef) -> ast.FunctionDef: """Visit a FunctionDef node.""" visited = cast(ast.FunctionDef, super().generic_visit(node)) args = ast.arguments( posonlyargs=self._replace_args(visited.args.posonlyargs), args=self._replace_args(visited.args.args), vararg=visited.args.vararg, kwonlyargs=self._replace_args(visited.args.kwonlyargs), kw_defaults=visited.args.kw_defaults, kwarg=visited.args.kwarg, defaults=visited.args.defaults, ) type_params = getattr(visited, "type_params", []) return ast_utils.create_function_def( name=self._mapping.get(visited.name, visited.name), args=args, body=visited.body, decorator_list=visited.decorator_list, returns=visited.returns, type_params=type_params, ) def visit_Name(self, node: ast.Name) -> ast.Name: """Visit a Name node.""" return ast.Name( id=self._mapping.get(node.id, node.id), ctx=node.ctx, ) ================================================ FILE: src/latexify/transformers/identifier_replacer_test.py ================================================ """Tests for latexify.transformer.identifier_replacer.""" from __future__ import annotations import ast import pytest from latexify import ast_utils, test_utils from latexify.transformers.identifier_replacer import IdentifierReplacer def test_invalid_mapping() -> None: with pytest.raises(ValueError, match=r"'123' is not an identifier name."): IdentifierReplacer({"123": "foo"}) with pytest.raises(ValueError, 
match=r"'456' is not an identifier name."): IdentifierReplacer({"foo": "456"}) with pytest.raises(ValueError, match=r"'def' is not an identifier name."): IdentifierReplacer({"foo": "def"}) def test_name_replaced() -> None: source = ast.Name(id="foo", ctx=ast.Load()) expected = ast.Name(id="bar", ctx=ast.Load()) transformed = IdentifierReplacer({"foo": "bar"}).visit(source) test_utils.assert_ast_equal(transformed, expected) def test_name_not_replaced() -> None: source = ast.Name(id="foo", ctx=ast.Load()) expected = ast.Name(id="foo", ctx=ast.Load()) transformed = IdentifierReplacer({"fo": "bar"}).visit(source) test_utils.assert_ast_equal(transformed, expected) transformed = IdentifierReplacer({"fooo": "bar"}).visit(source) test_utils.assert_ast_equal(transformed, expected) def test_functiondef_with_posonlyargs() -> None: # Subtree of: # @d # def f(x=a, /, y=b, *, z=c): # pass source = ast_utils.create_function_def( name="f", args=ast.arguments( posonlyargs=[ast.arg(arg="x")], args=[ast.arg(arg="y")], kwonlyargs=[ast.arg(arg="z")], kw_defaults=[ast.Name(id="c", ctx=ast.Load())], defaults=[ ast.Name(id="a", ctx=ast.Load()), ast.Name(id="b", ctx=ast.Load()), ], ), body=[ast.Pass()], decorator_list=[ast.Name(id="d", ctx=ast.Load())], returns=None, type_comment=None, type_params=[], ) expected = ast_utils.create_function_def( name="F", args=ast.arguments( posonlyargs=[ast.arg(arg="X")], args=[ast.arg(arg="Y")], kwonlyargs=[ast.arg(arg="Z")], kw_defaults=[ast.Name(id="C", ctx=ast.Load())], defaults=[ ast.Name(id="A", ctx=ast.Load()), ast.Name(id="B", ctx=ast.Load()), ], ), body=[ast.Pass()], decorator_list=[ast.Name(id="D", ctx=ast.Load())], returns=None, type_comment=None, type_params=[], ) mapping = {x: x.upper() for x in "abcdfxyz"} transformed = IdentifierReplacer(mapping).visit(source) test_utils.assert_ast_equal(transformed, expected) def test_expr() -> None: # Subtree of: # (x + y) * z source = ast.BinOp( left=ast.BinOp( left=ast.Name(id="x", ctx=ast.Load()), 
op=ast.Add(), right=ast.Name(id="y", ctx=ast.Load()), ), op=ast.Mult(), right=ast.Name(id="z", ctx=ast.Load()), ) expected = ast.BinOp( left=ast.BinOp( left=ast.Name(id="X", ctx=ast.Load()), op=ast.Add(), right=ast.Name(id="Y", ctx=ast.Load()), ), op=ast.Mult(), right=ast.Name(id="Z", ctx=ast.Load()), ) mapping = {x: x.upper() for x in "xyz"} transformed = IdentifierReplacer(mapping).visit(source) test_utils.assert_ast_equal(transformed, expected) ================================================ FILE: src/latexify/transformers/prefix_trimmer.py ================================================ """NodeTransformer to trim unnecessary prefixes.""" from __future__ import annotations import ast import re from latexify import ast_utils _PREFIX_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*$") class PrefixTrimmer(ast.NodeTransformer): """NodeTransformer to trim unnecessary prefixes. This class investigates all Attribute subtrees, and replace them if the prefix of the attribute matches the given set of prefixes. Prefix is searched in the manner of leftmost longest matching. Example: def f(x): return math.sqrt(x) PrefixTrimmer({"math"}) will modify the AST of the function above to below: def f(x): return sqrt(x) """ _prefixes: list[tuple[str, ...]] def __init__(self, prefixes: set[str]) -> None: """Initializer. Args: prefixes: Set of prefixes to be trimmed. Nested prefix is allowed too. Each value must follow one of the following formats: - A Python identifier, e.g., "math" - Python identifiers joined by periods, e.g., "numpy.random" """ for p in prefixes: if not _PREFIX_PATTERN.match(p): raise ValueError(f"Invalid prefix: {p}") self._prefixes = [tuple(p.split(".")) for p in prefixes] def _get_prefix(self, node: ast.expr) -> tuple[str, ...] | None: """Helper to obtain nested prefix. Args: node: Node to investigate. Returns: The prefix tuple, or None if the node has unsupported syntax. 
""" if isinstance(node, ast.Name): return (node.id,) if isinstance(node, ast.Attribute): parent = self._get_prefix(node.value) return parent + (node.attr,) if parent is not None else None return None def _make_attribute(self, prefix: tuple[str, ...], name: str) -> ast.expr: """Helper to generate a new Attribute or Name node. Args: prefix: List of prefixes. name: Attribute name. Returns: Name node if prefix == (), (possibly nested) Attribute node otherwise. """ if not prefix: return ast_utils.make_name(name) parent = self._make_attribute(prefix[:-1], prefix[-1]) return ast_utils.make_attribute(parent, name) def visit_Attribute(self, node: ast.Attribute) -> ast.expr: """Visit an Attribute node.""" prefix = self._get_prefix(node.value) if prefix is None: return node # Performs leftmost longest match. # NOTE(odashi): # This implementation is very naive, but would work efficiently as long as the # number of patterns is small. matched_length = 0 for p in self._prefixes: length = min(len(p), len(prefix)) if prefix[:length] == p and length > matched_length: matched_length = length return self._make_attribute(prefix[matched_length:], node.attr) ================================================ FILE: src/latexify/transformers/prefix_trimmer_test.py ================================================ """Tests for latexify.transformers.prefix_trimmer.""" from __future__ import annotations import ast import pytest from latexify import ast_utils, test_utils from latexify.transformers import prefix_trimmer # For convenience make_name = ast_utils.make_name make_attr = ast_utils.make_attribute PrefixTrimmer = prefix_trimmer.PrefixTrimmer @pytest.mark.parametrize( "prefix", [".x", "x.", "1", "1x", "x.1", "x.1x", "x.x.1", "x.x.1x" "x..x", "x.x..x"] ) def test_invalid_prefix(prefix: str) -> None: with pytest.raises(ValueError, match=rf"^Invalid prefix: {prefix}$"): PrefixTrimmer({prefix}) @pytest.mark.parametrize( "prefixes,expected", [ (set(), make_name("foo")), ({"foo"}, 
make_name("foo")), ({"bar"}, make_name("foo")), ({"foo.bar"}, make_name("foo")), ({"foo", "bar"}, make_name("foo")), ({"foo", "foo.bar"}, make_name("foo")), ], ) def test_name(prefixes: set[str], expected: ast.expr) -> None: source = make_name("foo") transformed = PrefixTrimmer(prefixes).visit(source) test_utils.assert_ast_equal(transformed, expected) @pytest.mark.parametrize( "prefixes,expected", [ (set(), make_attr(make_name("foo"), "bar")), ({"fo"}, make_attr(make_name("foo"), "bar")), ({"foo"}, make_name("bar")), ({"bar"}, make_attr(make_name("foo"), "bar")), ({"baz"}, make_attr(make_name("foo"), "bar")), ({"foo.bar"}, make_attr(make_name("foo"), "bar")), ({"foo", "bar"}, make_name("bar")), ({"foo", "foo.bar"}, make_name("bar")), ], ) def test_attr_1(prefixes: set[str], expected: ast.expr) -> None: source = make_attr(make_name("foo"), "bar") transformed = PrefixTrimmer(prefixes).visit(source) test_utils.assert_ast_equal(transformed, expected) @pytest.mark.parametrize( "prefixes,expected", [ (set(), make_attr(make_attr(make_name("foo"), "bar"), "baz")), ({"fo"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")), ({"foo"}, make_attr(make_name("bar"), "baz")), ({"bar"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")), ({"baz"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")), ({"foo.ba"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")), ({"foo.bar"}, make_name("baz")), ({"foo.bar.baz"}, make_attr(make_attr(make_name("foo"), "bar"), "baz")), ({"foo", "bar"}, make_attr(make_name("bar"), "baz")), ({"foo", "foo.bar"}, make_name("baz")), ], ) def test_attr_2(prefixes: set[str], expected: ast.expr) -> None: source = make_attr(make_attr(make_name("foo"), "bar"), "baz") transformed = PrefixTrimmer(prefixes).visit(source) test_utils.assert_ast_equal(transformed, expected)