Repository: m3dev/gokart Branch: master Commit: 0d0609000123 Files: 125 Total size: 466.8 KB Directory structure: gitextract_mhb_xs2d/ ├── .github/ │ ├── CODEOWNERS │ └── workflows/ │ ├── format.yml │ ├── publish.yml │ └── test.yml ├── .gitignore ├── .readthedocs.yaml ├── LICENSE ├── README.md ├── docs/ │ ├── Makefile │ ├── conf.py │ ├── efficient_run_on_multi_workers.rst │ ├── for_pandas.rst │ ├── gokart.rst │ ├── index.rst │ ├── intro_to_gokart.rst │ ├── logging.rst │ ├── make.bat │ ├── mypy_plugin.rst │ ├── polars.rst │ ├── requirements.txt │ ├── setting_task_parameters.rst │ ├── slack_notification.rst │ ├── task_information.rst │ ├── task_on_kart.rst │ ├── task_parameters.rst │ ├── task_settings.rst │ ├── tutorial.rst │ └── using_task_task_conflict_prevention_lock.rst ├── examples/ │ ├── gokart_notebook_example.ipynb │ ├── logging.ini │ └── param.ini ├── gokart/ │ ├── __init__.py │ ├── build.py │ ├── config_params.py │ ├── conflict_prevention_lock/ │ │ ├── task_lock.py │ │ └── task_lock_wrappers.py │ ├── errors/ │ │ └── __init__.py │ ├── file_processor/ │ │ ├── __init__.py │ │ ├── base.py │ │ ├── pandas.py │ │ └── polars.py │ ├── file_processor.py │ ├── gcs_config.py │ ├── gcs_obj_metadata_client.py │ ├── gcs_zip_client.py │ ├── in_memory/ │ │ ├── __init__.py │ │ ├── data.py │ │ ├── repository.py │ │ └── target.py │ ├── info.py │ ├── mypy.py │ ├── object_storage.py │ ├── pandas_type_config.py │ ├── parameter.py │ ├── py.typed │ ├── required_task_output.py │ ├── run.py │ ├── s3_config.py │ ├── s3_zip_client.py │ ├── slack/ │ │ ├── __init__.py │ │ ├── event_aggregator.py │ │ ├── slack_api.py │ │ └── slack_config.py │ ├── target.py │ ├── task.py │ ├── task_complete_check.py │ ├── testing/ │ │ ├── __init__.py │ │ ├── check_if_run_with_empty_data_frame.py │ │ └── pandas_assert.py │ ├── tree/ │ │ ├── task_info.py │ │ └── task_info_formatter.py │ ├── utils.py │ ├── worker.py │ ├── workspace_management.py │ ├── zip_client.py │ └── zip_client_util.py ├── luigi.cfg ├── 
pyproject.toml ├── test/ │ ├── __init__.py │ ├── config/ │ │ ├── __init__.py │ │ ├── pyproject.toml │ │ ├── pyproject_disallow_missing_parameters.toml │ │ └── test_config.ini │ ├── conflict_prevention_lock/ │ │ ├── __init__.py │ │ ├── test_task_lock.py │ │ └── test_task_lock_wrappers.py │ ├── file_processor/ │ │ ├── __init__.py │ │ ├── test_base.py │ │ ├── test_factory.py │ │ ├── test_pandas.py │ │ └── test_polars.py │ ├── in_memory/ │ │ ├── test_in_memory_target.py │ │ └── test_repository.py │ ├── slack/ │ │ ├── __init__.py │ │ └── test_slack_api.py │ ├── test_build.py │ ├── test_cache_unique_id.py │ ├── test_config_params.py │ ├── test_explicit_bool_parameter.py │ ├── test_gcs_config.py │ ├── test_gcs_obj_metadata_client.py │ ├── test_info.py │ ├── test_large_data_fram_processor.py │ ├── test_list_task_instance_parameter.py │ ├── test_mypy.py │ ├── test_pandas_type_check_framework.py │ ├── test_pandas_type_config.py │ ├── test_restore_task_by_id.py │ ├── test_run.py │ ├── test_s3_config.py │ ├── test_s3_zip_client.py │ ├── test_serializable_parameter.py │ ├── test_target.py │ ├── test_task_instance_parameter.py │ ├── test_task_on_kart.py │ ├── test_utils.py │ ├── test_worker.py │ ├── test_zoned_date_second_parameter.py │ ├── testing/ │ │ ├── __init__.py │ │ └── test_pandas_assert.py │ ├── tree/ │ │ ├── __init__.py │ │ ├── test_task_info.py │ │ └── test_task_info_formatter.py │ └── util.py └── tox.ini ================================================ FILE CONTENTS ================================================ ================================================ FILE: .github/CODEOWNERS ================================================ * @Hi-king @yokomotod @hirosassa @mski-iksm @kitagry @ujiuji1259 @mamo3gr @hiro-o918 ================================================ FILE: .github/workflows/format.yml ================================================ name: Lint on: push: branches: [ master ] pull_request: jobs: formatting-check: name: Lint runs-on: ubuntu-latest steps: 
- uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.13 tox --with tox-uv - name: Run ruff and mypy run: | uvx --with tox-uv tox run -e ruff,mypy ================================================ FILE: .github/workflows/publish.yml ================================================ name: Publish on: push: tags: '*' jobs: deploy: runs-on: ubuntu-latest steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true - name: Build and publish env: UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }} run: | uv build uv publish ================================================ FILE: .github/workflows/test.yml ================================================ name: Test on: push: branches: [ master ] pull_request: jobs: tests: runs-on: ${{ matrix.platform }} strategy: max-parallel: 7 matrix: platform: ["ubuntu-latest"] tox-env: ["py310", "py311", "py312", "py313", "py314"] include: - platform: macos-15 tox-env: "py313" - platform: macos-latest tox-env: "py313" steps: - uses: actions/checkout@v6 - name: Set up the latest version of uv uses: astral-sh/setup-uv@v7 with: enable-cache: true - name: Install dependencies run: | uv tool install --python-preference only-managed --python 3.13 tox --with tox-uv - name: Test with tox run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }} ================================================ FILE: .gitignore ================================================ # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions *.so # Distribution / packaging .Python build/ develop-eggs/ dist/ downloads/ eggs/ .eggs/ lib/ lib64/ parts/ sdist/ var/ wheels/ *.egg-info/ .installed.cfg *.egg MANIFEST # PyInstaller # Usually these files are written by a python script from a template # before PyInstaller builds 
the exe, so as to inject date/other infos into it. *.manifest *.spec # Installer logs pip-log.txt pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ .coverage .coverage.* .cache nosetests.xml coverage.xml *.cover .hypothesis/ .pytest_cache/ # Translations *.mo *.pot # Django stuff: *.log local_settings.py db.sqlite3 # Flask stuff: instance/ .webassets-cache # Scrapy stuff: .scrapy # Sphinx documentation docs/_build/ # PyBuilder target/ # Jupyter Notebook .ipynb_checkpoints # pyenv .python-version # celery beat schedule file celerybeat-schedule # SageMath parsed files *.sage.py # Environments .env .venv env/ venv/ ENV/ env.bak/ venv.bak/ # Spyder project settings .spyderproject .spyproject # Rope project settings .ropeproject # mkdocs documentation /site # mypy .mypy_cache/ # pycharm .idea # gokart resources examples/resources # poetry dist # temporary data temporary.zip ================================================ FILE: .readthedocs.yaml ================================================ # Read the Docs configuration file for Sphinx projects # See https://docs.readthedocs.io/en/stable/config-file/v2.html for details # Required version: 2 # Set the OS, Python version and other tools you might need build: os: ubuntu-24.04 tools: python: "3.12" # Build from the docs/ directory with Sphinx sphinx: configuration: docs/conf.py # Optional but recommended, declare the Python requirements required # to build your documentation # See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html python: install: - requirements: docs/requirements.txt ================================================ FILE: LICENSE ================================================ MIT License Copyright (c) 2018 M3, Inc. 
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: README.md ================================================ # gokart

[![Test](https://github.com/m3dev/gokart/workflows/Test/badge.svg)](https://github.com/m3dev/gokart/actions?query=workflow%3ATest) [![](https://readthedocs.org/projects/gokart/badge/?version=latest)](https://gokart.readthedocs.io/en/latest/) [![Python Versions](https://img.shields.io/pypi/pyversions/gokart.svg)](https://pypi.org/project/gokart/) [![](https://img.shields.io/pypi/v/gokart)](https://pypi.org/project/gokart/) ![](https://img.shields.io/pypi/l/gokart) Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline. [Documentation](https://gokart.readthedocs.io/en/latest/) for the latest release is hosted on readthedocs. # About gokart Here are some good things about gokart. - The following meta data for each Task is stored separately in a `pkl` file with hash value - task output data - imported all module versions - task processing time - random seed in task - displayed log - all parameters set as class variables in the task - Automatically rerun the pipeline if parameters of Tasks are changed. - Support GCS and S3 as a data store for intermediate results of Tasks in the pipeline. - The above output is exchanged between tasks as an intermediate file, which is memory-friendly - `pandas.DataFrame` type and column checking during I/O - Directory structure of saved files is automatically determined from structure of script - Seeds for numpy and random are automatically fixed - Can code while adhering to [SOLID](https://en.wikipedia.org/wiki/SOLID) principles as much as possible - Tasks are locked via redis even if they run in parallel **All the functions above are created for constructing Machine Learning batches. Provides an excellent environment for reproducibility and team development.** Here are some non-goal / downside of the gokart. - Batch execution in parallel is supported, but parallel and concurrent execution of task in memory. - Gokart is focused on reproducibility. 
So, I/O and capacity of data storage can become a bottleneck. - No support for task visualization. - Gokart is not an experiment management tool. The management of the execution result is cut out as [Thunderbolt](https://github.com/m3dev/thunderbolt). - Gokart does not recommend writing pipelines in toml, yaml, json, and more. Gokart prefers to write them in Python. # Getting Started Within the activated Python environment, use the following command to install gokart. ``` pip install gokart ``` # Quickstart ## Minimal Example A minimal gokart task looks something like this: ```python import gokart class Example(gokart.TaskOnKart): def run(self): self.dump('Hello, world!') task = Example() output = gokart.build(task) print(output) ``` `gokart.build` returns the result of dump by `gokart.TaskOnKart`. The example will output the following. ``` Hello, world! ``` ## Type-Safe Pipeline Example We introduce type-annotations to make a gokart pipeline robust. Please check the following example to see how to use type-annotations on gokart. Before using this feature, make sure to enable the [mypy plugin](https://gokart.readthedocs.io/en/latest/mypy_plugin.html) feature in your project.
```python import gokart # `gokart.TaskOnKart[str]` means that the task dumps `str` class StrDumpTask(gokart.TaskOnKart[str]): def run(self): self.dump('Hello, world!') # `gokart.TaskOnKart[int]` means that the task dumps `int` class OneDumpTask(gokart.TaskOnKart[int]): def run(self): self.dump(1) # `gokart.TaskOnKart[int]` means that the task dumps `int` class TwoDumpTask(gokart.TaskOnKart[int]): def run(self): self.dump(2) class AddTask(gokart.TaskOnKart[int]): # `a` requires a task to dump `int` a: gokart.TaskInstanceParameter[gokart.TaskOnKart[int]] = gokart.TaskInstanceParameter() # `b` requires a task to dump `int` b: gokart.TaskInstanceParameter[gokart.TaskOnKart[int]] = gokart.TaskInstanceParameter() def requires(self): return dict(a=self.a, b=self.b) def run(self): # loading by instance parameter, # `a` and `b` are treated as `int` # because they are declared as `gokart.TaskOnKart[int]` a = self.load(self.a) b = self.load(self.b) self.dump(a + b) valid_task = AddTask(a=OneDumpTask(), b=TwoDumpTask()) # the next line will show type error by mypy # because `StrDumpTask` dumps `str` and `AddTask` requires `int` invalid_task = AddTask(a=OneDumpTask(), b=StrDumpTask()) ``` This is an introduction to some of the gokart. There are still more useful features. Please See [Documentation](https://gokart.readthedocs.io/en/latest/) . Have a good gokart life. # Achievements Gokart is a proven product. - It's actually been used by [m3.inc](https://corporate.m3.com/en) for over 3 years - Natural Language Processing Competition by [Nishika.inc](https://nishika.com) 2nd prize : [Solution Repository](https://github.com/vaaaaanquish/nishika_akutagawa_2nd_prize) # Thanks gokart is a wrapper for luigi. Thanks to luigi and dependent projects! 
- [luigi](https://github.com/spotify/luigi) ================================================ FILE: docs/Makefile ================================================ # Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SOURCEDIR = . BUILDDIR = _build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) ================================================ FILE: docs/conf.py ================================================ # https://github.com/sphinx-doc/sphinx/issues/6211 import luigi import gokart luigi.task.Task.requires.__doc__ = gokart.task.TaskOnKart.requires.__doc__ luigi.task.Task.output.__doc__ = gokart.task.TaskOnKart.output.__doc__ # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. For a # full list see the documentation: # http://www.sphinx-doc.org/en/master/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. 
# import os # import sys # sys.path.insert(0, os.path.abspath('../gokart/')) # -- Project information ----------------------------------------------------- project = 'gokart' copyright = '2019, Masahiro Nishiba' author = 'Masahiro Nishiba' # The short X.Y version version = '' # The full version, including alpha/beta/rc tags release = '' # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode'] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The master toctree document. master_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store'] # The name of the Pygments (syntax highlighting) style to use. pygments_style = None # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'sphinx_rtd_theme' # Theme options are theme-specific and customize the look and feel of a theme # further. 
For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". html_static_path = [] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'gokartdoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. # # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'gokart.tex', 'gokart Documentation', 'Masahiro Nishiba', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [(master_doc, 'gokart', 'gokart Documentation', [author], 1)] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. 
List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'gokart', 'gokart Documentation', author, 'gokart', 'One line description of project.', 'Miscellaneous'), ] # -- Options for Epub output ------------------------------------------------- # Bibliographic Dublin Core info. epub_title = project # The unique identifier of the text. This can be a ISBN number # or the project homepage. # # epub_identifier = '' # A unique identification for the text. # # epub_uid = '' # A list of files that should not be packed into the epub file. epub_exclude_files = ['search.html'] ================================================ FILE: docs/efficient_run_on_multi_workers.rst ================================================ How to improve efficiency when running on multiple workers =========================================================== If multiple worker nodes are running similar gokart pipelines in parallel, it is possible that the exact same task may be executed by multiple workers. (For example, when training multiple machine learning models with different parameters, the feature creation task in the first stage is expected to be exactly the same.) It is inefficient to execute the same task on each of multiple worker nodes, so we want to avoid this. Here we introduce `should_lock_run` feature to improve this inefficiency. Suppress run() of the same task with `should_lock_run` ------------------------------------------------------ When `gokart.TaskOnKart.should_lock_run` is set to True, the task will fail if the same task is run()-ing by another worker. By failing the task, other tasks that can be executed at that time are given priority. After that, the failed task is automatically re-executed. .. 
code:: python class SampleTask2(gokart.TaskOnKart): should_lock_run = True Additional Option ------------------ Skip completed tasks with `complete_check_at_run` ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ By setting `gokart.TaskOnKart.complete_check_at_run` to True, the existence of the cache can be rechecked at run() time. Default is True, but if the check takes too much time, you can set it to False to deactivate the check. .. code:: python class SampleTask1(gokart.TaskOnKart): complete_check_at_run = False ================================================ FILE: docs/for_pandas.rst ================================================ For Pandas ========== Gokart has several features for Pandas. Pandas Type Config ------------------ Pandas has a feature that converts the type of column(s) automatically. This feature sometimes causes wrong results. To avoid unintentional type conversion of pandas, we can specify a column name to check the type of Task input and output in gokart. .. code:: python from typing import Any, Dict import pandas as pd import gokart # Please define a class which inherits `gokart.PandasTypeConfig`. class SamplePandasTypeConfig(gokart.PandasTypeConfig): @classmethod def type_dict(cls) -> Dict[str, Any]: return {'int_column': int} class SampleTask(gokart.TaskOnKart[pd.DataFrame]): def run(self): # [PandasTypeError] because expected type is `int`, but `str` is passed. df = pd.DataFrame(dict(int_column=['a'])) self.dump(df) This is useful when a dataframe has nullable columns because pandas auto-conversion often fails in such cases. Easy to Load DataFrame ---------------------- The :func:`~gokart.task.TaskOnKart.load` method is used to load input ``pandas.DataFrame``. .. code:: python def requires(self): return MakeDataFrameTask() def run(self): df = self.load() Please refer to :func:`~gokart.task.TaskOnKart.load`.
Fail on empty DataFrame ----------------------- When the :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump` parameter is true, the :func:`~gokart.task.TaskOnKart.dump()` method raises :class:`~gokart.errors.EmptyDumpError` on trying to dump empty ``pandas.DataFrame``. .. code:: python import gokart class EmptyTask(gokart.TaskOnKart): def run(self): df = pd.DataFrame() self.dump(df) :: $ python main.py EmptyTask --fail-on-empty-dump true # EmptyDumpError $ python main.py EmptyTask # Task will be ran and outputs an empty dataframe Empty caches sometimes hide bugs and let us spend much time debugging. This feature notifies us some bugs (including wrong datasources) in the early stage. Please refer to :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump`. ================================================ FILE: docs/gokart.rst ================================================ gokart package ============== Submodules ---------- gokart.file\_processor module ----------------------------- .. automodule:: gokart.file_processor :members: :undoc-members: :show-inheritance: gokart.info module ------------------ .. automodule:: gokart.info :members: :undoc-members: :show-inheritance: gokart.parameter module ----------------------- .. automodule:: gokart.parameter :members: :undoc-members: :show-inheritance: gokart.run module ----------------- .. automodule:: gokart.run :members: :undoc-members: :show-inheritance: gokart.s3\_config module ------------------------ .. automodule:: gokart.s3_config :members: :undoc-members: :show-inheritance: gokart.target module -------------------- .. automodule:: gokart.target :members: :undoc-members: :show-inheritance: gokart.task module ------------------ .. automodule:: gokart.task :members: :undoc-members: :show-inheritance: gokart.workspace\_management module ----------------------------------- .. automodule:: gokart.workspace_management :members: :undoc-members: :show-inheritance: gokart.zip\_client module ------------------------- .. 
automodule:: gokart.zip_client :members: :undoc-members: :show-inheritance: Module contents --------------- .. automodule:: gokart :members: :undoc-members: :show-inheritance: ================================================ FILE: docs/index.rst ================================================ .. gokart documentation master file, created by sphinx-quickstart on Fri Jan 11 07:59:25 2019. You can adapt this file completely to your liking, but it should at least contain the root `toctree` directive. Welcome to gokart's documentation! ================================== Useful links: `GitHub `_ | `cookiecutter gokart `_ `Gokart `_ is a wrapper of the data pipeline library `luigi `_. Gokart solves "**reproducibility**", "**task dependencies**", "**constraints of good code**", and "**ease of use**" for Machine Learning Pipeline. Good thing about gokart ----------------------- Here are some good things about gokart. - The following data for each Task is stored separately in a pkl file with hash value - task output data - imported all module versions - task processing time - random seed in task - displayed log - all parameters set as class variables in the task - If change parameter of Task, rerun spontaneously. - The above file will be generated with a different hash value - The hash value of dependent task will also change and both will be rerun - Support GCS or S3 - The above output is exchanged between tasks as an intermediate file, which is memory-friendly - pandas.DataFrame type and column checking during I/O - Directory structure of saved files is automatically determined from structure of script - Seeds for numpy and random are automatically fixed - Can code while adhering to SOLID principles as much as possible - Tasks are locked via redis even if they run in parallel **These are all functions baptized for creating Machine Learning batches. Provides an excellent environment for reproducibility and team development.** Getting started ----------------- .. 
toctree:: :maxdepth: 2 intro_to_gokart tutorial User Guide ----------------- .. toctree:: :maxdepth: 2 task_on_kart task_parameters setting_task_parameters task_settings task_information logging slack_notification using_task_task_conflict_prevention_lock efficient_run_on_multi_workers for_pandas polars mypy_plugin API References -------------- .. toctree:: :maxdepth: 2 gokart Indices and tables ------------------- * :ref:`genindex` * :ref:`modindex` * :ref:`search` ================================================ FILE: docs/intro_to_gokart.rst ================================================ Intro To Gokart =============== Installation ------------ Within the activated Python environment, use the following command to install gokart. .. code:: sh pip install gokart Quickstart ---------- A minimal gokart tasks looks something like this: .. code:: python import gokart class Example(gokart.TaskOnKart[str]): def run(self): self.dump('Hello, world!') task = Example() output = gokart.build(task) print(output) ``gokart.build`` return the result of dump by ``gokart.TaskOnKart``. The example will output the following. .. code:: sh Hello, world! ``gokart`` records all the information needed for Machine Learning. By default, ``resources`` will be generated in the same directory as the script. .. code:: sh $ tree resources/ resources/ ├── __main__ │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl └── log ├── module_versions │   └── Example_8441c59b5ce0113396d53509f19371fb.txt ├── processing_time │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl ├── random_seed │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl ├── task_log │   └── Example_8441c59b5ce0113396d53509f19371fb.pkl └── task_params └── Example_8441c59b5ce0113396d53509f19371fb.pkl The result of dumping the task will be saved in the ``__name__`` directory. .. code:: python import pickle with open('resources/__main__/Example_8441c59b5ce0113396d53509f19371fb.pkl', 'rb') as f: print(pickle.load(f)) # Hello, world! 
That will be given hash value depending on the parameter of the task. This means that if you change the parameter of the task, the hash value will change, and change output file. This is very useful when changing parameters and experimenting. Please refer to :doc:`task_parameters` section for task parameters. Also see :doc:`task_on_kart` section for information on how to return this output destination. In addition, the following files are automatically saved as ``log``. - ``module_versions``: The versions of all modules that were imported when the script was executed. For reproducibility. - ``processing_time``: The execution time of the task. - ``random_seed``: This is random seed of python and numpy. For reproducibility in Machine Learning. Please refer to :doc:`task_settings` section. - ``task_log``: This is the output of the task logger. - ``task_params``: This is task's parameters. Please refer to :doc:`task_parameters` section. How to running task ------------------- Gokart has ``run`` and ``build`` methods for running task. Each has a different purpose. - ``gokart.run``: uses arguments on the shell. return retcode. - ``gokart.build``: uses inline code on jupyter notebook, IPython, and more. return task output. .. note:: It is not recommended to use ``gokart.run`` and ``gokart.build`` together in the same script. Because ``gokart.build`` will clear the contents of ``luigi.register``. It's the only way to handle duplicate tasks. gokart.run ~~~~~~~~~~ The :func:`~gokart.run` is running on shell. .. code:: python import gokart import luigi class SampleTask(gokart.TaskOnKart[str]): param = luigi.Parameter() def run(self): self.dump(self.param) gokart.run() .. code:: sh python sample.py SampleTask --local-scheduler --param=hello If you were to write it in Python, it would be the same as the following behavior. .. code:: python gokart.run(['SampleTask', '--local-scheduler', '--param=hello']) gokart.build ~~~~~~~~~~~~ The :func:`~gokart.build` is inline code. .. 
code:: python import gokart import luigi class SampleTask(gokart.TaskOnKart[str]): param: luigi.Parameter = luigi.Parameter() def run(self): self.dump(self.param) gokart.build(SampleTask(param='hello'), return_value=False) To output logs of each tasks, you can pass `~log_level` parameter to `~gokart.build` as follows: .. code:: python gokart.build(SampleTask(param='hello'), return_value=False, log_level=logging.DEBUG) This feature is very useful for running `~gokart` on jupyter notebook. When some tasks are failed, gokart.build raises GokartBuildError. If you have to get tracebacks, you should set `log_level` as `logging.DEBUG`. ================================================ FILE: docs/logging.rst ================================================ Logging ======= How to set up a common logger for gokart. Core settings ------------- Please write a configuration file similar to the following: :: # base.ini [core] logging_conf_file=./conf/logging.ini .. code:: python import gokart gokart.add_config('base.ini') Logger ini file --------------- It is the same as a general logging.ini file. 
:: [loggers] keys=root,luigi,luigi-interface,gokart,gokart.file_processor [handlers] keys=stderrHandler [formatters] keys=simpleFormatter [logger_root] level=INFO handlers=stderrHandler [logger_gokart] level=INFO handlers=stderrHandler qualname=gokart propagate=0 [logger_luigi] level=INFO handlers=stderrHandler qualname=luigi propagate=0 [logger_luigi-interface] level=INFO handlers=stderrHandler qualname=luigi-interface propagate=0 [logger_gokart.file_processor] level=CRITICAL handlers=stderrHandler qualname=gokart.file_processor [handler_stderrHandler] class=StreamHandler formatter=simpleFormatter args=(sys.stdout,) [formatter_simpleFormatter] format=[%(asctime)s][%(name)s][%(levelname)s](%(filename)s:%(lineno)s) %(message)s datefmt=%Y/%m/%d %H:%M:%S Please refer to `Python logging documentation `_ ================================================ FILE: docs/make.bat ================================================ @ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=. set BUILDDIR=_build if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd ================================================ FILE: docs/mypy_plugin.rst ================================================ [Experimental] Mypy plugin =========================== Mypy plugin provides type checking for gokart tasks using Mypy. This feature is experimental. 
For example, the following code is linted by Mypy:
This allows gradual migration from pandas to Polars or using both libraries simultaneously in your data pipelines. Installation ------------ Polars support is optional. Install it with: .. code:: bash pip install gokart[polars] Or install Polars separately: .. code:: bash pip install polars Basic Usage ----------- To use Polars DataFrames with gokart, specify ``dataframe_type='polars'`` when creating file processors: .. code:: python import polars as pl from gokart import TaskOnKart from gokart.file_processor import FeatherFileProcessor class MyPolarsTask(TaskOnKart[pl.DataFrame]): def output(self): return self.make_target( 'path/to/target.feather', processor=FeatherFileProcessor( store_index_in_feather=False, dataframe_type='polars' ) ) def run(self): df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) self.dump(df) Supported File Processors -------------------------- The following file processors support the ``dataframe_type`` parameter: CsvFileProcessor ^^^^^^^^^^^^^^^^ .. code:: python from gokart.file_processor import CsvFileProcessor # For Polars processor = CsvFileProcessor(sep=',', encoding='utf-8', dataframe_type='polars') # For pandas (default) processor = CsvFileProcessor(sep=',', encoding='utf-8', dataframe_type='pandas') # or simply processor = CsvFileProcessor(sep=',', encoding='utf-8') JsonFileProcessor ^^^^^^^^^^^^^^^^^ .. code:: python from gokart.file_processor import JsonFileProcessor # For Polars processor = JsonFileProcessor(orient='records', dataframe_type='polars') # For pandas (default) processor = JsonFileProcessor(orient='records', dataframe_type='pandas') ParquetFileProcessor ^^^^^^^^^^^^^^^^^^^^ .. code:: python from gokart.file_processor import ParquetFileProcessor # For Polars processor = ParquetFileProcessor( compression='gzip', dataframe_type='polars' ) # For pandas (default) processor = ParquetFileProcessor( compression='gzip', dataframe_type='pandas' ) FeatherFileProcessor ^^^^^^^^^^^^^^^^^^^^ .. 
code:: python from gokart.file_processor import FeatherFileProcessor # For Polars processor = FeatherFileProcessor( store_index_in_feather=False, dataframe_type='polars' ) # For pandas (default) processor = FeatherFileProcessor( store_index_in_feather=True, dataframe_type='pandas' ) .. note:: The ``store_index_in_feather`` parameter is pandas-specific and is ignored when using Polars. Using Pandas and Polars Together --------------------------------- Since projects often migrate from pandas gradually, gokart allows you to use both pandas and Polars simultaneously: .. code:: python import pandas as pd import polars as pl from gokart import TaskOnKart from gokart.file_processor import FeatherFileProcessor class PandasTask(TaskOnKart[pd.DataFrame]): """Task that outputs pandas DataFrame""" def output(self): return self.make_target( 'path/to/pandas_output.feather', processor=FeatherFileProcessor( store_index_in_feather=False, dataframe_type='pandas' ) ) def run(self): df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) self.dump(df) class PolarsTask(TaskOnKart[pl.DataFrame]): """Task that outputs Polars DataFrame""" def requires(self): return PandasTask() def output(self): return self.make_target( 'path/to/polars_output.feather', processor=FeatherFileProcessor( store_index_in_feather=False, dataframe_type='polars' ) ) def run(self): # Load pandas DataFrame and convert to Polars pandas_df = self.load() # Returns pandas DataFrame polars_df = pl.from_pandas(pandas_df) # Process with Polars result = polars_df.with_columns( (pl.col('a') * 2).alias('a_doubled') ) self.dump(result) Default Behavior ---------------- When ``dataframe_type`` is not specified, file processors default to ``'pandas'`` for backward compatibility: .. 
code:: python # These are equivalent processor = CsvFileProcessor(sep=',') processor = CsvFileProcessor(sep=',', dataframe_type='pandas') Important Notes --------------- **File Format Compatibility** Files created with Polars processors can be read by pandas processors and vice versa. The underlying file formats (CSV, JSON, Parquet, Feather) are library-agnostic. **Pandas-specific Features** Some pandas-specific features are not available with Polars: - ``store_index_in_feather`` parameter in ``FeatherFileProcessor`` is ignored for Polars - ``engine`` parameter in ``ParquetFileProcessor`` is ignored for Polars (uses Polars' default) **Error Handling** If you specify ``dataframe_type='polars'`` but Polars is not installed, you'll get an ``ImportError`` with installation instructions: .. code:: text ImportError: polars is required for dataframe_type='polars'. Install with: pip install polars Migration Strategy ------------------ Recommended approach for migrating from pandas to Polars: 1. Install Polars: ``pip install gokart[polars]`` 2. Create new tasks using ``dataframe_type='polars'`` 3. Keep existing tasks with ``dataframe_type='pandas'`` or default behavior 4. Gradually migrate tasks as needed 5. Convert DataFrames between libraries using ``pl.from_pandas()`` and ``df.to_pandas()`` when necessary ================================================ FILE: docs/requirements.txt ================================================ Sphinx gokart sphinx-rtd-theme ================================================ FILE: docs/setting_task_parameters.rst ================================================ ============================ Setting Task Parameters ============================ There are several ways to set task parameters. - Set parameter from command line - Set parameter at config file - Set parameter at upstream task - Inherit parameter from other task Set parameter from command line ================================== .. 
This is useful when multiple tasks have the same parameter.
A Slack bot token can be obtained from `slack app document `_.
When "simple" is passed, it outputs the states and the unique ids of tasks.
make_task_info_as_tree_str() ----------------------------------------- ``gokart.tree.task_info.make_task_info_as_tree_str()`` will return a tree dependency tree as a str. .. code:: python from gokart.tree.task_info import make_task_info_as_tree_str make_task_info_as_tree_str(task, ignore_task_names) # Parameters # ---------- # - task: TaskOnKart # Root task. # - details: bool # Whether or not to output details. # - abbr: bool # Whether or not to simplify tasks information that has already appeared. # - ignore_task_names: Optional[List[str]] # List of task names to ignore. # Returns # ------- # - tree_info : str # Formatted task dependency tree. example .. code:: python import luigi import gokart class TaskA(gokart.TaskOnKart[str]): param = luigi.Parameter() def run(self): self.dump(f'{self.param}') class TaskB(gokart.TaskOnKart[str]): task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter() def run(self): task = self.load('task') self.dump(task + ' taskB') class TaskC(gokart.TaskOnKart[str]): task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter() def run(self): task = self.load('task') self.dump(task + ' taskC') class TaskD(gokart.TaskOnKart): task1: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter() task2: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter() def run(self): task = [self.load('task1'), self.load('task2')] self.dump(','.join(task)) .. code:: python task = TaskD( task1=TaskD( task1=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))), # same task task2=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))) # same task ), task2=TaskD( task1=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))), # same task task2=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))) # same task ) ) print(gokart.make_task_info_as_tree_str(task)) .. 
In the above example, the sub-trees already shown are omitted.
code:: python from gokart.tree.task_info import dump_task_info_table dump_task_info_table(task, task_info_dump_path, ignore_task_names) # Parameters # ---------- # - task: TaskOnKart # Root task. # - task_info_dump_path: str # Output target file path. Path destination can be `local`, `S3`, or `GCS`. # File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`. # See `TaskOnKart.make_target module ` for details. # - ignore_task_names: Optional[List[str]] # List of task names to ignore. # Returns # ------- # None 6. dump_task_info_tree() ----------------------------------------- ``gokart.tree.task_info.dump_task_info_tree()`` will dump the task tree object (TaskInfo) to a pickle file. .. code:: python from gokart.tree.task_info import dump_task_info_tree dump_task_info_tree(task, task_info_dump_path, ignore_task_names, use_unique_id) # Parameters # ---------- # - task: TaskOnKart # Root task. # - task_info_dump_path: str # Output target file path. Path destination can be `local`, `S3`, or `GCS`. # File extension must be '.pkl'. # - ignore_task_names: Optional[List[str]] # List of task names to ignore. # - use_unique_id: bool = True # Whether to use unique id to dump target file. Default is True. # Returns # ------- # None Task Logs --------- To output extra information of tasks by ``tree-info``, the member variable :attr:`~gokart.task.TaskOnKart.task_log` of ``TaskOnKart`` keeps any information as a dictionary. For instance, the following code runs, .. code:: python import gokart class SampleTaskLog(gokart.TaskOnKart): def run(self): # Add some logs. self.task_log['sample key'] = 'sample value' if __name__ == '__main__': SampleTaskLog().run() gokart.run([ '--tree-info-mode=all', '--tree-info-output-path=sample_task_log.txt', 'SampleTaskLog', '--local-scheduler']) the output could be like: .. 
This option is currently supported only when a task outputs files to local storage, not S3.
To dump something other than the above, you can use :func:`~gokart.TaskOnKart.make_model_target`.
For instance, an example implementation that `requires` returns a dictionary of tasks could be like follows: .. code:: python def requires(self): return dict(a=TaskA(), b=TaskB()) def run(self): data = self.load() # returns dict(a=self.load('a'), b=self.load('b')) The `load` method loads individual task input by passing a key of an input dictionary as follows: .. code:: python def run(self): data_a = self.load('a') data_b = self.load('b') As an alternative, the `load` method loads individual task input by passing an instance of TaskOnKart as follows: .. code:: python def run(self): data_a = self.load(TaskA()) data_b = self.load(TaskB()) We can also omit the :func:`~gokart.task.TaskOnKart.requires` and write the task used by :func:`~gokart.parameter.TaskInstanceParameter`. Also please refer to :func:`~gokart.task.TaskOnKart.load`, :doc:`task_parameters`, and described later Advanced Features section. TaskOnKart.dump ---------------- The :func:`~gokart.task.TaskOnKart.dump` method is used to dump results of tasks. For instance, an example implementation could be as follows: .. code:: python def output(self): return self.make_target('output.pkl') def run(self): results = do_something(self.load()) self.dump(results) In the case that a task has 2 or more output, it is possible to specify output target by passing a key of dictionary like follows: .. code:: python def output(self): return dict(a=self.make_target('output_a.pkl'), b=self.make_target('output_b.pkl')) def run(self): a_data = do_something_a(self.load()) b_data = do_something_b(self.load()) self.dump(a_data, 'a') self.dump(b_data, 'b') Please refer to :func:`~gokart.task.TaskOnKart.dump`. Advanced Features --------------------- TaskOnKart.load_generator ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ The :func:`~gokart.task.TaskOnKart.load_generator` method is used to load input data with generator. For instance, an example implementation could be as follows: .. 
It's effective when you can't read all the data into memory, because `load_generator` doesn't load all files at once.
code:: python from gokart.file_processor import CsvFileProcessor def output(self): return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932')) # This will dump csv as 'cp932' which is used in Windows. Cache output in memory instead of dumping to files ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ You can use :class:`~InMemoryTarget` to cache output in memory instead of dumping to files by calling :func:`~gokart.target.make_in_memory_target`. Please note that :class:`~InMemoryTarget` is an experimental feature. .. code:: python from gokart.in_memory.target import make_in_memory_target def output(self): unique_id = self.make_unique_id() if use_unique_id else None # TaskLock is not supported in InMemoryTarget, so it's dummy task_lock_params = make_task_lock_params( file_path='dummy_path', unique_id=unique_id, redis_host=None, redis_port=None, redis_timeout=self.redis_timeout, raise_task_lock_exception_on_collision=False, ) return make_in_memory_target('dummy_path', task_lock_params, unique_id) ================================================ FILE: docs/task_parameters.rst ================================================ ================= Task Parameters ================= Luigi Parameter ================ We can set parameters for tasks. Also please refer to :doc:`task_settings` section. .. code:: python class Task(gokart.TaskOnKart): param_a: luigi.Parameter = luigi.Parameter() param_c: luigi.ListParameter = luigi.ListParameter() param_d: luigi.IntParameter = luigi.IntParameter(default=1) Please refer to `luigi document `_ for a list of parameter types. Gokart Parameter ================ There are also parameters provided by gokart. - gokart.TaskInstanceParameter - gokart.ListTaskInstanceParameter - gokart.ExplicitBoolParameter gokart.TaskInstanceParameter -------------------------------- The :func:`~gokart.parameter.TaskInstanceParameter` executes a task using the results of a task as dynamic parameters. .. 
The :func:`~gokart.parameter.ExplicitBoolParameter` is a parameter for explicitly specified values.
bar: str def gokart_serialize(self) -> str: # Serialize only the `foo` field since `bar` is irrelevant for caching. return json.dumps({'foo': self.foo}) @classmethod def gokart_deserialize(cls, s: str) -> 'Config': # Deserialize the object from the provided string. return cls(**json.loads(s)) class DummyTask(gokart.TaskOnKart): config: gokart.SerializableParameter[Config] = gokart.SerializableParameter(object_type=Config) def run(self): # Save the `config` object as part of the task result. self.dump(self.config) ================================================ FILE: docs/task_settings.rst ================================================ Task Settings ============= Task settings. Also please refer to :doc:`task_parameters` section. Directory to Save Outputs ------------------------- We can use both a local directory and the S3 to save outputs. If you would like to use local directory, please set a local directory path to :attr:`~gokart.task.TaskOnKart.workspace_directory`. Please refer to :doc:`task_parameters` for how to set it up. It is recommended to use the config file since it does not change much. :: # base.ini [TaskOnKart] workspace_directory=${TASK_WORKSPACE_DIRECTORY} .. code:: python # main.py import gokart gokart.add_config('base.ini') To use the S3 or GCS repository, please set the bucket path as ``s3://{YOUR_REPOSITORY_NAME}`` or ``gs://{YOUR_REPOSITORY_NAME}`` respectively. If use S3 or GCS, please set credential information to Environment Variables. .. code:: sh # S3 export AWS_ACCESS_KEY_ID='~~~' # AWS access key export AWS_SECRET_ACCESS_KEY='~~~' # AWS secret access key # GCS export GCS_CREDENTIAL='~~~' # GCS credential export DISCOVER_CACHE_LOCAL_PATH='~~~' # The local file path of discover api cache. Rerun task ---------- There are times when we want to rerun a task, such as when change script or on batch. Please use the ``rerun`` parameter or add an arbitrary parameter. When set rerun as follows: .. 
code:: python # rerun TaskA gokart.build(Task(rerun=True)) When used from an argument as follows: .. code:: python # main.py class Task(gokart.TaskOnKart[str]): def run(self): self.dump('hello') .. code:: sh python main.py Task --local-scheduler --rerun ``rerun`` parameter will look at the dependent tasks up to one level. Example: Suppose we have a straight line pipeline composed of TaskA, TaskB and TaskC, and TaskC is an endpoint of this pipeline. We also suppose that all the tasks have already been executed. - TaskA(rerun=True) -> TaskB -> TaskC # not rerunning - TaskA -> TaskB(rerun=True) -> TaskC # rerunning TaskB and TaskC This is due to the way intermediate files are handled. ``rerun`` parameter is ``significant=False``, it does not affect the hash value. It is very important to understand this difference. If you want to change the parameter of TaskA and rerun TaskB and TaskC, recommend adding an arbitrary parameter. .. code:: python class TaskA(gokart.TaskOnKart): __version: luigi.IntParameter = luigi.IntParameter(default=1) If the hash value of TaskA will change, the dependent tasks (in this case, TaskB and TaskC) will rerun. Fix random seed --------------- Every task has a parameter named :attr:`~gokart.task.TaskOnKart.fix_random_seed_methods` and :attr:`~gokart.task.TaskOnKart.fix_random_seed_value`. This can be used to fix the random seed. .. code:: python import gokart import random import numpy import torch class Task(gokart.TaskOnKart[dict[str, Any]]): def run(self): x = [random.randint(0, 100) for _ in range(0, 10)] y = [np.random.randint(0, 100) for _ in range(0, 10)] z = [torch.randn(1).tolist()[0] for _ in range(0, 5)] self.dump({'random': x, 'numpy': y, 'torch': z}) gokart.build( Task( fix_random_seed_methods=[ "random.seed", "numpy.random.seed", "torch.random.manual_seed"], fix_random_seed_value=57)) :: # //--- The output is as follows every time. 
--- # {'random': [65, 41, 61, 37, 55, 81, 48, 2, 94, 21], # 'numpy': [79, 86, 5, 22, 79, 98, 56, 40, 81, 37], 'torch': []} # 'torch': [0.14460121095180511, -0.11649507284164429, # 0.6928958296775818, -0.916053831577301, 0.7317505478858948]} This will be useful for using Machine Learning Libraries. ================================================ FILE: docs/tutorial.rst ================================================ Tutorial ======== Also please refer to :doc:`intro_to_gokart` section. 1, Make gokart project ---------------------- Create a project using `cookiecutter-gokart `_. .. code:: sh cookiecutter https://github.com/m3dev/cookiecutter-gokart # project_name [project_name]: example # package_name [package_name]: gokart_example # python_version [3.7.0]: # author [your name]: m3dev # package_description [What's this project?]: gokart example # license [MIT License]: You will have a directory tree like following: .. code:: sh tree example/ example/ ├── Dockerfile ├── README.md ├── conf │   ├── logging.ini │   └── param.ini ├── gokart_example │   ├── __init__.py │   ├── model │   │   ├── __init__.py │   │   └── sample.py │   └── utils │   └── template.py ├── main.py ├── pyproject.toml └── test ├── __init__.py └── unit_test └── test_sample.py 2, Running sample task ---------------------- Let's run the first task. .. code:: sh python main.py gokart_example.Sample --local-scheduler The results are stored in resources directory. .. code:: sh tree resources resources/ ├── gokart_example │   └── model │   └── sample │   └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl └── log ├── module_versions │   └── Sample_cdf55a3d6c255d8c191f5f472da61f99.txt ├── processing_time │   └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl ├── random_seed │   └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl ├── task_log │   └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl └── task_params └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl Please refer to :doc:`intro_to_gokart` for output .. 
note:: It is better to use poetry in terms of the module version. Please refer to `poetry document `_ .. code:: sh poetry lock poetry run python main.py gokart_example.Sample --local-scheduler If want to stabilize it further, please use docker. .. code:: sh docker build -t sample . docker run -it sample "python main.py gokart_example.Sample --local-scheduler" 3, Check result --------------- Check the output. .. code:: python with open('resources/gokart_example/model/sample/Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl', 'rb') as f: print(pickle.load(f)) # sample output 4, Run unittest ------------------ It is important to run unittest before and after modifying the code. .. code:: sh python -m unittest discover -s ./test/unit_test/ . ---------------------------------------------------------------------- Ran 1 test in 0.001s OK 5, Create Task -------------- Writing gokart-like tasks. Modify ``example/gokart_example/model/sample.py`` as follows: .. code:: python from logging import getLogger import gokart from gokart_example.utils.template import GokartTask logger = getLogger(__name__) class Sample(GokartTask): def run(self): self.dump('sample output') class StringToSplit(GokartTask): """Like the function to divide received data by spaces.""" task: gokart.TaskInstanceParameter = gokart.TaskInstanceParameter() def run(self): sample = self.load('task') self.dump(sample.split(' ')) class Main(GokartTask): """Endpoint task.""" def requires(self): return StringToSplit(task=Sample()) Added ``Main`` and ``StringToSplit``. ``StringToSplit`` is a function-like task that loads the result of an arbitrary task and splits it by spaces. ``Main`` is injecting ``Sample`` into ``StringToSplit``. It like Endpoint. Let’s run the ``Main`` task. .. code:: sh python main.py gokart_example.Main --local-scheduler Please take a look at the logger output at this time. 
:: ===== Luigi Execution Summary ===== Scheduled 3 tasks of which: * 1 complete ones were encountered: - 1 gokart_example.Sample(...) * 2 ran successfully: - 1 gokart_example.Main(...) - 1 gokart_example.StringToSplit(...) This progress looks :) because there were no failed tasks or missing dependencies ===== Luigi Execution Summary ===== As the log shows, ``Sample`` has been executed once, so the ``cache`` will be used. The only things that worked were ``Main`` and ``StringToSplit``. The output will look like the following, with the result in ``StringToSplit_b8a0ce6c972acbd77eae30f35da4307e.pkl``. :: tree resources/ resources/ ├── gokart_example │   └── model │   └── sample │   ├── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl │   └── StringToSplit_b8a0ce6c972acbd77eae30f35da4307e.pkl ... .. code:: python with open('resources/gokart_example/model/sample/StringToSplit_b8a0ce6c972acbd77eae30f35da4307e.pkl', 'rb') as f: print(pickle.load(f)) # ['sample', 'output'] It was able to move the added task. 6, Rerun Task ------------- Finally, let's rerun the task. There are two ways to rerun a task. Change the ``rerun parameter`` or ``parameters of the dependent tasks``. ``gokart.TaskOnKart`` can set ``rerun parameter`` for each task like following: .. code:: python class Main(GokartTask): rerun=True def requires(self): return StringToSplit(task=Sample(rerun=True), rerun=True) OR Add new parameter on dependent tasks like following: .. code:: python class Sample(GokartTask): version: luigi.IntParameter = luigi.IntParameter(default=1) def run(self): self.dump('sample output version {self.version}') In both cases, all tasks will be rerun. The difference is hash value given to output files. The reurn parameter has no effect on the hash value. So it will be rerun with the same hash value. In the second method, ``version parameter`` is added to the ``Sample`` task. This parameter will change the hash value of ``Sample`` and generate another output file. 
And the dependent task, ``StringToSplit``, will also have a different hash value, and rerun. Please refer to :doc:`task_settings` for details. Please try rerunning task at hand:) Feature ------- This is the end of the gokart tutorial. The tutorial is an introduction to some of the features. There are still more useful features. Please See :doc:`task_on_kart` section, :doc:`for_pandas` section and :doc:`task_parameters` section for more useful features of the task. Have a good gokart life. ================================================ FILE: docs/using_task_task_conflict_prevention_lock.rst ================================================ Task conflict prevention lock ========================= If there is a possibility of multiple worker nodes executing the same task, task cache conflict may happen. Specifically, while node A is loading the cache of a task, node B may be writing to it. This can lead to reading an inappropriate data and other unwanted behaviors. The redis lock introduced in this page is a feature to prevent such cache collisions. Requires -------- You need to install `redis `_ for using this advanced feature. How to use ----------- 1. Set up a redis server at somewhere accessible from gokart/luigi jobs. e.g. Following script will run redis at your localhost. .. code:: bash $ redis-server 2. Set redis server hostname and port number as parameters of gokart.TaskOnKart(). You can set it by adding ``--redis-host=[your-redis-localhost] --redis-port=[redis-port-number]`` options to gokart python script. e.g. .. code:: bash python main.py sample.SomeTask --local-scheduler --redis-host=localhost --redis-port=6379 Alternatively, you may set parameters at config file. e.g. .. code:: [TaskOnKart] redis_host=localhost redis_port=6379 3. 
Done With the above configuration, all tasks that inherits gokart.TaskOnKart will ask the redis server if any other node is not trying to access the same cache file at the same time whenever they access the file with dump or load. ================================================ FILE: examples/gokart_notebook_example.ipynb ================================================ { "cells": [ { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "gokart 1.0.2\n" ] } ], "source": [ "!pip list | grep gokart" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [], "source": [ "import gokart\n", "import luigi" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Examples of using gokart at jupyter notebook" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Basic Usage\n", "This is a very basic usage, just to dump a run result of ExampleTaskA." ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "example_2\n" ] } ], "source": [ "class ExampleTaskA(gokart.TaskOnKart):\n", " param = luigi.Parameter()\n", " int_param = luigi.IntParameter(default=2)\n", "\n", " def run(self):\n", " self.dump(f'DONE {self.param}_{self.int_param}')\n", "\n", " \n", "task_a = ExampleTaskA(param='example')\n", "output = gokart.build(task=task_a)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make tasks dependencies with `requires()`\n", "ExampleTaskB is dependent on ExampleTaskC and ExampleTaskD. They are defined in `ExampleTaskB.requires()`." 
] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DONE example_TASKC_TASKD\n" ] } ], "source": [ "class ExampleTaskC(gokart.TaskOnKart):\n", " def run(self):\n", " self.dump('TASKC')\n", " \n", "class ExampleTaskD(gokart.TaskOnKart):\n", " def run(self):\n", " self.dump('TASKD')\n", "\n", "class ExampleTaskB(gokart.TaskOnKart):\n", " param = luigi.Parameter()\n", "\n", " def requires(self):\n", " return dict(task_c=ExampleTaskC(), task_d=ExampleTaskD())\n", "\n", " def run(self):\n", " task_c = self.load('task_c')\n", " task_d = self.load('task_d')\n", " self.dump(f'DONE {self.param}_{task_c}_{task_d}')\n", " \n", "task_b = ExampleTaskB(param='example')\n", "output = gokart.build(task=task_b)\n", "print(output)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Make tasks dependencies with TaskInstanceParameter\n", "The dependencies are same as previous example, however they are defined at the outside of the task instead of defied at `ExampleTaskB.requires()`." 
] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "DONE example_TASKC_TASKD\n" ] } ], "source": [ "class ExampleTaskC(gokart.TaskOnKart):\n", " def run(self):\n", " self.dump('TASKC')\n", " \n", "class ExampleTaskD(gokart.TaskOnKart):\n", " def run(self):\n", " self.dump('TASKD')\n", "\n", "class ExampleTaskB(gokart.TaskOnKart):\n", " param = luigi.Parameter()\n", " task_1 = gokart.TaskInstanceParameter()\n", " task_2 = gokart.TaskInstanceParameter()\n", "\n", " def requires(self):\n", " return dict(task_1=self.task_1, task_2=self.task_2) # required tasks are decided from the task parameters `task_1` and `task_2`\n", "\n", " def run(self):\n", " task_1 = self.load('task_1')\n", " task_2 = self.load('task_2')\n", " self.dump(f'DONE {self.param}_{task_1}_{task_2}')\n", " \n", "task_b = ExampleTaskB(param='example', task_1=ExampleTaskC(), task_2=ExampleTaskD()) # Dependent tasks are defined here\n", "output = gokart.build(task=task_b)\n", "print(output)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3.8.8 64-bit ('3.8.8': pyenv)", "name": "python388jvsc74a57bd026997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.8" }, "metadata": { "interpreter": { "hash": "26997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca" } }, "orig_nbformat": 3 }, "nbformat": 4, "nbformat_minor": 2 } ================================================ FILE: examples/logging.ini ================================================ [loggers] keys=root,luigi,luigi-interface,gokart [handlers] keys=stderrHandler [formatters] keys=simpleFormatter [logger_root] level=INFO handlers=stderrHandler [logger_gokart] level=INFO handlers=stderrHandler qualname=gokart 
propagate=0 [logger_luigi] level=INFO handlers=stderrHandler qualname=luigi propagate=0 [logger_luigi-interface] level=INFO handlers=stderrHandler qualname=luigi-interface propagate=0 [handler_stderrHandler] class=StreamHandler formatter=simpleFormatter args=(sys.stdout,) [formatter_simpleFormatter] format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s datefmt=%Y/%m/%d %H:%M:%S class=logging.Formatter ================================================ FILE: examples/param.ini ================================================ [TaskOnKart] workspace_directory=./resource local_temporary_directory=./resource/tmp [core] logging_conf_file=logging.ini ================================================ FILE: gokart/__init__.py ================================================ __all__ = [ 'build', 'WorkerSchedulerFactory', 'make_tree_info', 'tree_info', 'PandasTypeConfig', 'ExplicitBoolParameter', 'ListTaskInstanceParameter', 'SerializableParameter', 'TaskInstanceParameter', 'ZonedDateSecondParameter', 'run', 'TaskOnKart', 'test_run', 'make_task_info_as_tree_str', 'add_config', 'delete_local_unnecessary_outputs', ] from gokart.build import WorkerSchedulerFactory, build from gokart.info import make_tree_info, tree_info from gokart.pandas_type_config import PandasTypeConfig from gokart.parameter import ( ExplicitBoolParameter, ListTaskInstanceParameter, SerializableParameter, TaskInstanceParameter, ZonedDateSecondParameter, ) from gokart.run import run from gokart.task import TaskOnKart from gokart.testing import test_run from gokart.tree.task_info import make_task_info_as_tree_str from gokart.utils import add_config from gokart.workspace_management import delete_local_unnecessary_outputs ================================================ FILE: gokart/build.py ================================================ from __future__ import annotations import enum import io import logging from dataclasses import dataclass from functools 
import partial from logging import getLogger from typing import Any, Literal, Protocol, TypeVar, cast, overload import backoff import luigi from luigi import LuigiStatusCode, rpc, scheduler import gokart import gokart.tree.task_info from gokart import worker from gokart.conflict_prevention_lock.task_lock import TaskLockException from gokart.target import TargetOnKart from gokart.task import TaskOnKart T = TypeVar('T') logger: logging.Logger = logging.getLogger(__name__) class LoggerConfig: def __init__(self, level: int): self.logger = getLogger(__name__) self.default_level = self.logger.level self.level = level def __enter__(self): logging.disable(self.level - 10) # subtract 10 to disable below self.level self.logger.setLevel(self.level) return self def __exit__(self, exception_type, exception_value, traceback): logging.disable(self.default_level - 10) # subtract 10 to disable below self.level self.logger.setLevel(self.default_level) class GokartBuildError(Exception): """Raised when ``gokart.build`` failed. This exception contains raised exceptions in the task execution.""" def __init__(self, message: str, raised_exceptions: dict[str, list[Exception]]) -> None: super().__init__(message) self.raised_exceptions = raised_exceptions class HasLockedTaskException(Exception): """Raised when the task failed to acquire the lock in the task execution.""" class TaskLockExceptionRaisedFlag: def __init__(self): self.flag: bool = False class WorkerProtocol(Protocol): """Protocol for Worker. This protocol is determined by luigi.worker.Worker. """ def add(self, task: TaskOnKart[Any]) -> bool: ... def run(self) -> bool: ... def __enter__(self) -> WorkerProtocol: ... def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]: ... 
class WorkerSchedulerFactory: def create_local_scheduler(self) -> scheduler.Scheduler: return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False) def create_remote_scheduler(self, url: str) -> rpc.RemoteScheduler: return rpc.RemoteScheduler(url) def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant: bool = False) -> WorkerProtocol: return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant) def _get_output(task: TaskOnKart[T]) -> T: output = task.output() # FIXME: currently, nested output is not supported if isinstance(output, list) or isinstance(output, tuple): return cast(T, [t.load() for t in output if isinstance(t, TargetOnKart)]) if isinstance(output, dict): return cast(T, {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)}) if isinstance(output, TargetOnKart): return cast(T, output.load()) raise ValueError(f'output type is not supported: {type(output)}') def _reset_register(keep={'gokart', 'luigi'}): """reset luigi.task_register.Register._reg everytime gokart.build called to avoid TaskClassAmbigiousException""" luigi.task_register.Register._reg = [ x for x in luigi.task_register.Register._reg if ( (x.__module__.split('.')[0] in keep) # keep luigi and gokart or (issubclass(x, gokart.PandasTypeConfig)) ) # PandasTypeConfig should be kept ] class TaskDumpMode(enum.Enum): TREE = 'tree' TABLE = 'table' NONE = 'none' class TaskDumpOutputType(enum.Enum): PRINT = 'print' DUMP = 'dump' NONE = 'none' @dataclass class TaskDumpConfig: mode: TaskDumpMode = TaskDumpMode.NONE output_type: TaskDumpOutputType = TaskDumpOutputType.NONE def process_task_info(task: TaskOnKart[Any], task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None: match task_dump_config: case TaskDumpConfig(mode=TaskDumpMode.NONE, output_type=TaskDumpOutputType.NONE): pass case TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.PRINT): tree = gokart.make_tree_info(task) 
logger.info(tree) case TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.PRINT): table = gokart.tree.task_info.make_task_info_as_table(task) output = io.StringIO() table.to_csv(output, index=False, sep='\t') output.seek(0) logger.info(output.read()) case TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.DUMP): tree = gokart.make_tree_info(task) gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.txt').dump(tree) case TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.DUMP): table = gokart.tree.task_info.make_task_info_as_table(task) gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.pkl').dump(table) case _: raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}') @overload def build( task: TaskOnKart[T], return_value: Literal[True] = True, reset_register: bool = True, log_level: int = logging.ERROR, task_lock_exception_max_tries: int = 10, task_lock_exception_max_wait_seconds: int = 600, **env_params: Any, ) -> T: ... @overload def build( task: TaskOnKart[T], return_value: Literal[False], reset_register: bool = True, log_level: int = logging.ERROR, task_lock_exception_max_tries: int = 10, task_lock_exception_max_wait_seconds: int = 600, **env_params: Any, ) -> None: ... def build( task: TaskOnKart[T], return_value: bool = True, reset_register: bool = True, log_level: int = logging.ERROR, task_lock_exception_max_tries: int = 10, task_lock_exception_max_wait_seconds: int = 600, task_dump_config: TaskDumpConfig = TaskDumpConfig(), **env_params: Any, ) -> T | None: """ Run gokart task for local interpreter. 
Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build) """ if reset_register: _reset_register() with LoggerConfig(level=log_level): log_handler_before_run = logging.StreamHandler() logger.addHandler(log_handler_before_run) process_task_info(task, task_dump_config) logger.removeHandler(log_handler_before_run) log_handler_before_run.close() task_lock_exception_raised = TaskLockExceptionRaisedFlag() raised_exceptions: dict[str, list[Exception]] = dict() @TaskOnKart.event_handler(luigi.Event.FAILURE) def when_failure(task, exception): if isinstance(exception, TaskLockException): task_lock_exception_raised.flag = True else: raised_exceptions.setdefault(task.make_unique_id(), []).append(exception) @backoff.on_exception( partial(backoff.expo, max_value=task_lock_exception_max_wait_seconds), HasLockedTaskException, max_tries=task_lock_exception_max_tries ) def _build_task(): task_lock_exception_raised.flag = False result = luigi.build( [task], worker_scheduler_factory=WorkerSchedulerFactory(), local_scheduler=True, detailed_summary=True, log_level=logging.getLevelName(log_level), **env_params, ) if task_lock_exception_raised.flag: raise HasLockedTaskException() if result.status in (LuigiStatusCode.FAILED, LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED, LuigiStatusCode.SCHEDULING_FAILED): raise GokartBuildError(result.summary_text, raised_exceptions=raised_exceptions) return _get_output(task) if return_value else None return cast(T | None, _build_task()) ================================================ FILE: gokart/config_params.py ================================================ from __future__ import annotations from typing import Any import luigi import gokart class inherits_config_params: def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str, str] | None = None): """ Decorates task to inherit parameter value of `config_class`. 
* config_class: Inherit parameter value of this task to decorated task. Only parameter values exist in both tasks are inherited. * parameter_alias: Dictionary to map paramter names between config_class task and decorated task. key: config_class's parameter name. value: decorated task's parameter name. """ self._config_class: type[luigi.Config] = config_class self._parameter_alias: dict[str, str] = parameter_alias if parameter_alias is not None else {} def __call__(self, task_class: type[gokart.TaskOnKart[Any]]) -> type[gokart.TaskOnKart[Any]]: # wrap task to prevent task name from being changed @luigi.task._task_wraps(task_class) class Wrapped(task_class): # type: ignore @classmethod def get_param_values(cls, params, args, kwargs): for param_key, param_value in self._config_class().param_kwargs.items(): task_param_key = self._parameter_alias.get(param_key, param_key) if hasattr(cls, task_param_key) and task_param_key not in kwargs: kwargs[task_param_key] = param_value return super().get_param_values(params, args, kwargs) return Wrapped ================================================ FILE: gokart/conflict_prevention_lock/task_lock.py ================================================ from __future__ import annotations import functools import os from logging import getLogger from typing import Any, NamedTuple import redis from apscheduler.schedulers.background import BackgroundScheduler logger = getLogger(__name__) class TaskLockParams(NamedTuple): redis_host: str | None redis_port: int | None redis_timeout: int | None redis_key: str should_task_lock: bool raise_task_lock_exception_on_collision: bool lock_extend_seconds: int class TaskLockException(Exception): pass """Raised when the task failed to acquire the lock in the task execution. 
Only used internally.""" class RedisClient: _instances: dict[Any, Any] = {} def __new__(cls, *args, **kwargs): key = (args, tuple(sorted(kwargs.items()))) if cls not in cls._instances: cls._instances[cls] = {} if key not in cls._instances[cls]: cls._instances[cls][key] = super().__new__(cls) return cls._instances[cls][key] def __init__(self, host: str | None, port: int | None) -> None: if not hasattr(self, '_redis_client'): host = host or 'localhost' port = port or 6379 self._redis_client = redis.Redis(host=host, port=port) def get_redis_client(self): return self._redis_client def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int) -> None: task_lock.extend(additional_time=redis_timeout, replace_ttl=True) def set_task_lock(task_lock_params: TaskLockParams) -> redis.lock.Lock: redis_client = RedisClient(host=task_lock_params.redis_host, port=task_lock_params.redis_port).get_redis_client() blocking = not task_lock_params.raise_task_lock_exception_on_collision task_lock = redis.lock.Lock(redis=redis_client, name=task_lock_params.redis_key, timeout=task_lock_params.redis_timeout, thread_local=False) if not task_lock.acquire(blocking=blocking): raise TaskLockException('Lock already taken by other task.') return task_lock def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockParams) -> BackgroundScheduler: scheduler = BackgroundScheduler() extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=task_lock_params.redis_timeout or 0) scheduler.add_job( extend_lock, 'interval', seconds=task_lock_params.lock_extend_seconds, max_instances=999999999, misfire_grace_time=task_lock_params.redis_timeout, coalesce=False, ) scheduler.start() return scheduler def make_task_lock_key(file_path: str, unique_id: str | None) -> str: basename_without_ext = os.path.splitext(os.path.basename(file_path))[0] return f'{basename_without_ext}_{unique_id}' def make_task_lock_params( file_path: str, unique_id: str | None, redis_host: str | 
None = None, redis_port: int | None = None, redis_timeout: int | None = None, raise_task_lock_exception_on_collision: bool = False, lock_extend_seconds: int = 10, ) -> TaskLockParams: redis_key = make_task_lock_key(file_path, unique_id) should_task_lock = redis_host is not None and redis_port is not None if redis_timeout is not None: assert redis_timeout > lock_extend_seconds, f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.' task_lock_params = TaskLockParams( redis_host=redis_host, redis_port=redis_port, redis_key=redis_key, should_task_lock=should_task_lock, redis_timeout=redis_timeout, raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision, lock_extend_seconds=lock_extend_seconds, ) return task_lock_params def make_task_lock_params_for_run(task_self: Any, lock_extend_seconds: int = 10) -> TaskLockParams: task_path_name = os.path.join(task_self.__module__.replace('.', '/'), f'{type(task_self).__name__}') unique_id = task_self.make_unique_id() + '-run' task_lock_key = make_task_lock_key(file_path=task_path_name, unique_id=unique_id) should_task_lock = task_self.redis_host is not None and task_self.redis_port is not None return TaskLockParams( redis_host=task_self.redis_host, redis_port=task_self.redis_port, redis_key=task_lock_key, should_task_lock=should_task_lock, redis_timeout=task_self.redis_timeout, raise_task_lock_exception_on_collision=True, lock_extend_seconds=lock_extend_seconds, ) ================================================ FILE: gokart/conflict_prevention_lock/task_lock_wrappers.py ================================================ from __future__ import annotations import functools from collections.abc import Callable from logging import getLogger from typing import ParamSpec, TypeVar from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock logger = getLogger(__name__) P = ParamSpec('P') R = TypeVar('R') def 
wrap_dump_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams, exist_check: Callable[..., bool]) -> Callable[P, R | None]: """Redis lock wrapper function for TargetOnKart.dump(). When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check. https://github.com/m3dev/gokart/issues/265 """ if not task_lock_params.should_task_lock: return func def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None: task_lock = set_task_lock(task_lock_params=task_lock_params) scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) try: logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} locked.') if not exist_check(): return func(*args, **kwargs) return None finally: logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} released.') task_lock.release() scheduler.shutdown() return wrapper def wrap_load_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]: """Redis lock wrapper function for TargetOnKart.load(). When TargetOnKart.load() is called, redis lock will be locked and released before load(). https://github.com/m3dev/gokart/issues/265 """ if not task_lock_params.should_task_lock: return func def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: task_lock = set_task_lock(task_lock_params=task_lock_params) scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} locked.') task_lock.release() logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} released.') scheduler.shutdown() result = func(*args, **kwargs) return result return wrapper def wrap_remove_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]: """Redis lock wrapper function for TargetOnKart.remove(). When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock. 
https://github.com/m3dev/gokart/issues/265 """ if not task_lock_params.should_task_lock: return func def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: task_lock = set_task_lock(task_lock_params=task_lock_params) scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) try: logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} locked.') result = func(*args, **kwargs) task_lock.release() logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released.') scheduler.shutdown() return result except BaseException as e: logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released with BaseException.') task_lock.release() scheduler.shutdown() raise e return wrapper def wrap_run_with_lock(run_func: Callable[[], R], task_lock_params: TaskLockParams) -> Callable[[], R]: @functools.wraps(run_func) def wrapped(): task_lock = set_task_lock(task_lock_params=task_lock_params) scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params) try: logger.debug(f'Task RUN lock of {task_lock_params.redis_key} locked.') result = run_func() task_lock.release() logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released.') scheduler.shutdown() return result except BaseException as e: logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released with BaseException.') task_lock.release() scheduler.shutdown() raise e return wrapped ================================================ FILE: gokart/errors/__init__.py ================================================ from gokart.build import GokartBuildError, HasLockedTaskException from gokart.pandas_type_config import PandasTypeError from gokart.task import EmptyDumpError __all__ = [ 'GokartBuildError', 'HasLockedTaskException', 'PandasTypeError', 'EmptyDumpError', ] ================================================ FILE: gokart/file_processor/__init__.py ================================================ """File processor module with support for 
class CsvFileProcessor(FileProcessor):
    """CSV file processor with automatic backend selection based on dataframe_type."""

    def __init__(self, sep: str = ',', encoding: str = 'utf-8', dataframe_type: DataFrameType = 'pandas') -> None:
        """
        CSV file processor with support for both pandas and polars DataFrames.

        Args:
            sep: CSV delimiter (default: ',')
            encoding: File encoding (default: 'utf-8')
            dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas')
        """
        self._sep = sep
        self._encoding = encoding
        self._dataframe_type = dataframe_type  # Store for tests
        if dataframe_type in ('polars', 'polars-lazy'):
            # Both polars variants share one backend; `lazy` selects scan vs eager read.
            self._impl: FileProcessor = CsvFileProcessorPolars(sep=sep, encoding=encoding, lazy=(dataframe_type == 'polars-lazy'))
        else:
            self._impl = CsvFileProcessorPandas(sep=sep, encoding=encoding)

    def format(self):
        return self._impl.format()

    def load(self, file):
        return self._impl.load(file)

    def dump(self, obj, file):
        return self._impl.dump(obj, file)
class ParquetFileProcessor(FileProcessor):
    """Parquet file processor with automatic backend selection based on dataframe_type."""

    def __init__(self, engine: Any = 'pyarrow', compression: Any = None, dataframe_type: DataFrameType = 'pandas') -> None:
        """
        Parquet file processor with support for both pandas and polars DataFrames.

        Args:
            engine: Parquet engine (pandas-specific, ignored for polars).
            compression: Compression type.
            dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas')
        """
        self._engine = engine
        self._compression = compression
        self._dataframe_type = dataframe_type  # Store for tests
        if dataframe_type in ('polars', 'polars-lazy'):
            # Both polars variants share one backend; `lazy` selects scan vs eager read.
            self._impl: FileProcessor = ParquetFileProcessorPolars(engine=engine, compression=compression, lazy=(dataframe_type == 'polars-lazy'))
        else:
            self._impl = ParquetFileProcessorPandas(engine=engine, compression=compression)

    def format(self):
        return self._impl.format()

    def load(self, file):
        return self._impl.load(file)

    def dump(self, obj, file):
        # Delegate to the configured backend (pandas by default).
        return self._impl.dump(obj, file)
dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas') """ self._store_index_in_feather = store_index_in_feather self._dataframe_type = dataframe_type # Store for tests if dataframe_type == 'polars-lazy': self._impl: FileProcessor = FeatherFileProcessorPolars(store_index_in_feather=store_index_in_feather, lazy=True) elif dataframe_type == 'polars': self._impl = FeatherFileProcessorPolars(store_index_in_feather=store_index_in_feather, lazy=False) else: self._impl = FeatherFileProcessorPandas(store_index_in_feather=store_index_in_feather) def format(self): return self._impl.format() def load(self, file): return self._impl.load(file) def dump(self, obj, file): # Use the configured implementation (pandas by default) return self._impl.dump(obj, file) def make_file_processor(file_path: str, store_index_in_feather: bool = True, *, dataframe_type: DataFrameType = 'pandas') -> FileProcessor: """Create a file processor based on file extension with default parameters.""" extension2processor = { '.txt': TextFileProcessor(), '.ini': TextFileProcessor(), '.csv': CsvFileProcessor(sep=',', dataframe_type=dataframe_type), '.tsv': CsvFileProcessor(sep='\t', dataframe_type=dataframe_type), '.pkl': PickleFileProcessor(), '.gz': GzipFileProcessor(), '.json': JsonFileProcessor(dataframe_type=dataframe_type), '.ndjson': JsonFileProcessor(dataframe_type=dataframe_type, orient='records'), '.xml': XmlFileProcessor(), '.npz': NpzFileProcessor(), '.parquet': ParquetFileProcessor(compression='gzip', dataframe_type=dataframe_type), '.feather': FeatherFileProcessor(store_index_in_feather=store_index_in_feather, dataframe_type=dataframe_type), '.png': BinaryFileProcessor(), '.jpg': BinaryFileProcessor(), } extension = os.path.splitext(file_path)[1] assert extension in extension2processor, f'{extension} is not supported. The supported extensions are {list(extension2processor.keys())}.' 
logger = getLogger(__name__)

# Type alias for DataFrame library return type
DataFrameType = Literal['pandas', 'polars', 'polars-lazy']


class FileProcessor:
    """Abstract interface for (de)serializing objects to and from file targets."""

    @abstractmethod
    def format(self) -> Any: ...

    @abstractmethod
    def load(self, file: Any) -> Any: ...

    @abstractmethod
    def dump(self, obj: Any, file: Any) -> None: ...


class BinaryFileProcessor(FileProcessor):
    """
    Pass bytes to this processor
    ```
    figure_binary = io.BytesIO()
    plt.savefig(figure_binary)
    figure_binary.seek(0)
    BinaryFileProcessor().dump(figure_binary.read())
    ```
    """

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # Raw passthrough: return the file's bytes untouched.
        return file.read()

    def dump(self, obj, file):
        file.write(obj)


class _ChunkedLargeFileReader:
    """File wrapper that serves reads of 2 GiB or more in smaller chunks.

    Works around file objects that cannot honor a single huge read();
    everything else is proxied to the wrapped file via __getattr__.
    """

    def __init__(self, file: Any) -> None:
        self._file = file

    def __getattr__(self, item):
        return getattr(self._file, item)

    def read(self, n: int) -> bytes:
        if n < (1 << 31):
            # Small enough to delegate directly.
            return cast(bytes, self._file.read(n))
        logger.info(f'reading a large file with total_bytes={n}.')
        result = bytearray(n)
        position = 0
        while position < n:
            step = min(n - position, (1 << 31) - 1)
            logger.info(f'reading bytes [{position}, {position + step})...')
            result[position : position + step] = self._file.read(step)
            position += step
        logger.info('done.')
        return bytes(result)

    def readline(self) -> bytes:
        return cast(bytes, self._file.readline())

    def seek(self, offset: int) -> None:
        self._file.seek(offset)

    def seekable(self) -> bool:
        return cast(bool, self._file.seekable())
class TextFileProcessor(FileProcessor):
    """Loads text files as a list of right-stripped lines; dumps str()-able objects."""

    def format(self):
        return None

    def load(self, file):
        return [line.rstrip() for line in file.readlines()]

    def dump(self, obj, file):
        if not isinstance(obj, list):
            # A scalar is written as-is, without a trailing newline.
            file.write(str(obj))
        else:
            file.writelines(str(item) + '\n' for item in obj)


class GzipFileProcessor(FileProcessor):
    """Same contract as TextFileProcessor, but over luigi's Gzip format (bytes in/out)."""

    def format(self):
        return luigi.format.Gzip

    def load(self, file):
        return [line.rstrip().decode() for line in file.readlines()]

    def dump(self, obj, file):
        if not isinstance(obj, list):
            file.write(str(obj).encode())
        else:
            file.writelines((str(item) + '\n').encode() for item in obj)


class XmlFileProcessor(FileProcessor):
    """Parses XML into an ElementTree; an unparsable file yields an empty tree."""

    def format(self):
        return None

    def load(self, file):
        try:
            return ET.parse(file)
        except ET.ParseError:
            # Malformed/empty XML degrades to an empty tree instead of raising.
            return ET.ElementTree()

    def dump(self, obj, file):
        assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.'
        obj.write(file)
class CsvFileProcessorPandas(FileProcessor):
    """CSV file processor for pandas DataFrames."""

    def __init__(self, sep: str = ',', encoding: str = 'utf-8') -> None:
        super().__init__()
        self._sep = sep
        self._encoding = encoding

    def format(self):
        return TextFormat(encoding=self._encoding)

    def load(self, file):
        try:
            return pd.read_csv(file, sep=self._sep, encoding=self._encoding)
        except pd.errors.EmptyDataError:
            # An empty file is a valid (empty) frame, not an error.
            return pd.DataFrame()

    def dump(self, obj, file):
        if not isinstance(obj, (pd.DataFrame, pd.Series)):
            raise TypeError(f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.')
        obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)
class ParquetFileProcessorPandas(FileProcessor):
    """Parquet file processor for pandas DataFrames."""

    def __init__(self, engine: Literal['auto', 'pyarrow', 'fastparquet'] = 'pyarrow', compression: str | None = None) -> None:
        super().__init__()
        self._engine: Literal['auto', 'pyarrow', 'fastparquet'] = engine
        self._compression = compression

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_parquet accepts a file-like object, but the file
        # (luigi.contrib.s3.ReadableS3File) would need a 'tell' method,
        # which pandas requires for chunked reads.
        if ObjectStorage.is_buffered_reader(file):
            return pd.read_parquet(file.name)
        return pd.read_parquet(BytesIO(file.read()))

    def dump(self, obj, file):
        if not isinstance(obj, pd.DataFrame):
            raise TypeError(f'requires pd.DataFrame, but {type(obj)} is passed.')
        # MEMO: to_parquet only supports a filepath as string (not a file handle)
        obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)
if ObjectStorage.is_buffered_reader(file): loaded_df = pd.read_feather(file.name) else: loaded_df = pd.read_feather(BytesIO(file.read())) if self._store_index_in_feather: if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns): index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX] index_column = index_columns[0] index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :] if index_name == 'None': index_name = None loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name) loaded_df = loaded_df.drop(columns=[index_column]) return loaded_df def dump(self, obj, file): if not isinstance(obj, pd.DataFrame): raise TypeError(f'requires pd.DataFrame, but {type(obj)} is passed.') dump_obj = obj.copy() if self._store_index_in_feather: index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}' assert index_column_name not in dump_obj.columns, ( f'column name {index_column_name} already exists in dump_obj. \nConsider not saving index by setting store_index_in_feather=False.' ) assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.' 
_CsvEncoding = Literal['utf8', 'utf8-lossy']
_ParquetCompression = Literal['lz4', 'uncompressed', 'snappy', 'gzip', 'brotli', 'zstd']

try:
    import polars as pl

    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False

if TYPE_CHECKING:
    import polars as pl


class CsvFileProcessorPolars(FileProcessor):
    """CSV file processor for polars DataFrames."""

    def __init__(self, sep: str = ',', encoding: str = 'utf-8', lazy: bool = False) -> None:
        if not HAS_POLARS:
            raise ImportError("polars is required for polars-based dataframe types ('polars' or 'polars-lazy'). Install with: pip install polars")
        super().__init__()
        self._sep = sep
        self._encoding = encoding
        self._lazy = lazy

    def format(self):
        return TextFormat(encoding=self._encoding)

    def load(self, file):
        # scan_csv/read_csv only support 'utf8' and 'utf8-lossy'
        resolved_encoding: _CsvEncoding = 'utf8' if self._encoding in ('utf-8', 'utf8') else 'utf8-lossy'
        try:
            if self._lazy:
                # scan_csv requires a file path, not a file object
                return pl.scan_csv(file.name, separator=self._sep, encoding=resolved_encoding)
            return pl.read_csv(file, separator=self._sep, encoding=resolved_encoding)
        except Exception as e:
            # Handle empty data gracefully
            if 'empty' in str(e).lower() or 'no data' in str(e).lower():
                return pl.LazyFrame() if self._lazy else pl.DataFrame()
            raise

    def dump(self, obj, file):
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect()
        if not isinstance(obj, pl.DataFrame):
            raise TypeError(f'requires pl.DataFrame or pl.LazyFrame, but {type(obj)} is passed.')
        obj.write_csv(file, separator=self._sep, include_header=True)
class ParquetFileProcessorPolars(FileProcessor):
    """Parquet file processor for polars DataFrames."""

    def __init__(self, engine: str = 'pyarrow', compression: _ParquetCompression | None = None, lazy: bool = False) -> None:
        if not HAS_POLARS:
            raise ImportError("polars is required for polars-based dataframe types ('polars' or 'polars-lazy'). Install with: pip install polars")
        super().__init__()
        self._engine = engine  # Ignored for polars
        self._compression: _ParquetCompression | None = compression
        self._lazy = lazy

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # Buffered readers expose a real path which scan_parquet can use;
        # other file objects are drained into an in-memory buffer instead.
        if ObjectStorage.is_buffered_reader(file):
            return pl.scan_parquet(file.name) if self._lazy else pl.read_parquet(file.name)
        payload = BytesIO(file.read())
        # scan_parquet doesn't work with BytesIO, so read eagerly and convert if lazy.
        frame = pl.read_parquet(payload)
        return frame.lazy() if self._lazy else frame

    def dump(self, obj, file):
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect()
        if not isinstance(obj, pl.DataFrame):
            raise TypeError(f'requires pl.DataFrame or pl.LazyFrame, but {type(obj)} is passed.')
        # polars write_parquet requires a file path; default to 'zstd' when compression is None
        obj.write_parquet(file.name, compression=self._compression or 'zstd')
class GCSConfig(luigi.Config):
    """Luigi config that lazily builds and caches a GCS client from an env-var credential."""

    gcs_credential_name: luigi.StrParameter = luigi.StrParameter(default='GCS_CREDENTIAL', description='GCS credential environment variable.')
    _client = None

    def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
        # use cache as like singleton object
        if self._client is None:
            self._client = self._get_gcs_client()
        return self._client

    def _get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
        return luigi.contrib.gcs.GCSClient(oauth_credentials=self._load_oauth_credentials())

    def _load_oauth_credentials(self) -> Credentials | None:
        raw = os.environ.get(self.gcs_credential_name)
        if not raw:
            return None
        # The env var may hold either a path to a service-account file or the JSON payload itself.
        if os.path.isfile(raw):
            return cast(Credentials, Credentials.from_service_account_file(raw))
        return cast(Credentials, Credentials.from_service_account_info(json.loads(raw)))
bucket, obj = GCSObjectMetadataClient._path_to_bucket_and_key(path) _response = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute() if _response is None: logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.') return response: dict[str, Any] = dict(_response) original_metadata: dict[Any, Any] = {} if 'metadata' in response.keys(): _metadata = response.get('metadata') if _metadata is not None: original_metadata = dict(_metadata) patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata( copy.deepcopy(original_metadata), task_params, custom_labels, required_task_outputs, ) if original_metadata != patched_metadata: # If we use update api, existing object metadata are removed, so should use patch api. # See the official document descriptions. # [Link] https://cloud.google.com/storage/docs/viewing-editing-metadata?hl=ja#rest-set-object-metadata update_response = ( GCSConfig() .get_gcs_client() .client.objects() .patch( bucket=bucket, object=obj, body=makepatch({'metadata': original_metadata}, {'metadata': patched_metadata}), ) .execute() ) if update_response is None: logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.') @staticmethod def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]: return {str(key): str(value) for key, value in labels.items()} if labels else {} @staticmethod def _get_patched_obj_metadata( metadata: Any, task_params: dict[str, str] | None = None, custom_labels: dict[str, str] | None = None, required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> dict[str, Any] | Any: # If metadata from response when getting bucket and object information is not dictionary, # something wrong might be happened, so return original metadata, no patched. 
if not isinstance(metadata, dict): logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.') return metadata # Maximum size of metadata for each object is 8 KiB. # [Link]: https://cloud.google.com/storage/quotas#objects normalized_task_params_labels = GCSObjectMetadataClient._normalize_labels(task_params) normalized_custom_labels = GCSObjectMetadataClient._normalize_labels(custom_labels) # There is a possibility that the keys of user-provided labels(custom_labels) may conflict with those generated from task parameters (task_params_labels). # However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters. # Instead, users are expected to search using the labels they provided. # Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence. normalized_labels = [normalized_custom_labels, normalized_task_params_labels] if required_task_outputs: normalized_labels.append({'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))}) _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels) return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels) @staticmethod def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]: if isinstance(required_task_outputs, RequiredTaskOutput): return required_task_outputs.serialize() elif isinstance(required_task_outputs, dict): return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()} elif isinstance(required_task_outputs, Iterable): return [GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs] else: raise TypeError( f'Unsupported type for required_task_outputs: 
{type(required_task_outputs)}. ' 'It should be RequiredTaskOutput, dict, or iterable of RequiredTaskOutput.' ) @staticmethod def _merge_custom_labels_and_task_params_labels( normalized_labels_list: list[dict[str, str]], ) -> dict[str, str]: def __merge_two_dicts_helper(merged: dict[str, str], current_labels: dict[str, str]) -> dict[str, str]: next_merged = copy.deepcopy(merged) for label_name, label_value in current_labels.items(): if len(label_value) == 0: logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.') continue if label_name in next_merged: logger.warning(f'label_name={label_name} is already seen. So skip to add as metadata.') continue next_merged[label_name] = label_value return next_merged return functools.reduce(__merge_two_dicts_helper, normalized_labels_list, {}) # Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB. # So, we need to adjust the size of metadata. @staticmethod def _adjust_gcs_metadata_limit_size(_labels: dict[str, str]) -> dict[str, str]: def _get_label_size(label_name: str, label_value: str) -> int: return len(label_name.encode('utf-8')) + len(label_value.encode('utf-8')) labels = copy.deepcopy(_labels) max_gcs_metadata_size, current_total_metadata_size = ( GCSObjectMetadataClient.MAX_GCS_METADATA_SIZE, sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()), ) if current_total_metadata_size <= max_gcs_metadata_size: return labels # NOTE: remove labels to stay within max metadata size. 
class GCSZipClient(ZipClient):
    """ZipClient implementation that archives/unpacks a local working directory against GCS."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory
        self._client = GCSConfig().get_gcs_client()

    def exists(self) -> bool:
        return cast(bool, self._client.exists(self._file_path))

    def make_archive(self) -> None:
        # shutil derives the archive format ('zip', ...) from the remote path's extension.
        archive_format = os.path.splitext(self._file_path)[1][1:]
        shutil.make_archive(base_name=self._temporary_directory, format=archive_format, root_dir=self._temporary_directory)
        self._client.put(self._temporary_file_path(), self._file_path)

    def unpack_archive(self) -> None:
        os.makedirs(self._temporary_directory, exist_ok=True)
        file_pointer = self._client.download(self._file_path)
        _unzip_file(fp=file_pointer, extract_dir=self._temporary_directory)

    def remove(self) -> None:
        self._client.remove(self._file_path)

    @property
    def path(self) -> str:
        return self._file_path

    def _temporary_file_path(self):
        # shutil.make_archive appended the extension to the (slash-stripped)
        # temporary directory name; rebuild that path here.
        extension = os.path.splitext(self._file_path)[1]
        return self._temporary_directory.removesuffix('/') + extension
@dataclass
class InMemoryData:
    """Container pairing a cached value with the time it was stored."""

    # The cached object itself.
    value: Any
    # When the value was (re)stored; used for freshness comparisons.
    last_modification_time: datetime

    @classmethod
    def create_data(cls, value: Any) -> InMemoryData:
        """Build an entry for ``value`` stamped with the current time.

        Fixed: the first parameter was previously named ``self`` on a classmethod,
        and the body hardcoded ``InMemoryData(...)``; using ``cls`` keeps the
        factory subclass-friendly and follows PEP 8.
        """
        return cls(value=value, last_modification_time=datetime.now())
class InMemoryTarget(TargetOnKart):
    """A ``TargetOnKart`` whose payload lives in the process-wide in-memory cache."""

    def __init__(self, data_key: str, task_lock_param: TaskLockParams):
        # Task locking is Redis-backed, which makes no sense for an in-process cache.
        if task_lock_param.should_task_lock:
            raise ValueError('Redis with `InMemoryTarget` is not currently supported.')
        self._data_key = data_key
        self._task_lock_params = task_lock_param

    def _exists(self) -> bool:
        return _repository.has(self._data_key)

    def _get_task_lock_params(self) -> TaskLockParams:
        return self._task_lock_params

    def _load(self) -> Any:
        return _repository.get_value(self._data_key)

    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        # The metadata arguments are accepted for interface compatibility only;
        # the in-memory cache stores just the object itself.
        return _repository.set_value(self._data_key, obj)

    def _remove(self) -> None:
        _repository.remove(self._data_key)

    def _last_modification_time(self) -> datetime:
        if not _repository.has(self._data_key):
            raise ValueError(f'No object(s) which id is {self._data_key} are stored before.')
        return _repository.get_last_modification_time(self._data_key)

    def _path(self) -> str:
        # TODO: the name `_path` might not be appropriate here — this is a cache
        # key rather than a filesystem path.
        return self._data_key
def make_tree_info(
    task: TaskOnKart[Any],
    indent: str = '',
    last: bool = True,
    details: bool = False,
    abbr: bool = True,
    visited_tasks: set[str] | None = None,
    ignore_task_names: list[str] | None = None,
) -> str:
    """Render the dependency tree of ``task`` as a formatted string.

    Deprecated shim: the implementation moved to
    ``gokart.tree.task_info.make_task_info_as_tree_str``; this wrapper is kept
    for backward compatibility only.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - details: bool
        Whether or not to output details.
    - abbr: bool
        Whether or not to simplify tasks information that has already appeared.
    - ignore_task_names: list[str] | None
        List of task names to ignore.

    Returns
    -------
    - tree_info : str
        Formatted task dependency tree.
    """
    # `indent`, `last` and `visited_tasks` remain in the signature only for
    # backward compatibility; they are not forwarded to the new implementation.
    tree_str = make_task_info_as_tree_str(task=task, details=details, abbr=abbr, ignore_task_names=ignore_task_names)
    return tree_str
class PluginOptions(Enum):
    """Option keys recognized in the ``[tool.gokart-mypy]`` section of pyproject.toml."""

    DISALLOW_MISSING_PARAMETERS = 'disallow_missing_parameters'
disallow_missing_parameters: bool = False @classmethod def _parse_toml(cls, config_file: str) -> dict[str, Any]: if sys.version_info >= (3, 11): import tomllib as toml_ else: try: import tomli as toml_ except ImportError: # pragma: no cover warnings.warn('install tomli to parse pyproject.toml under Python 3.10', stacklevel=1) return {} with open(config_file, 'rb') as f: return toml_.load(f) @classmethod def parse_config_file(cls, config_file: str) -> TaskOnKartPluginOptions: # TODO: support other configuration file formats if necessary. if not config_file.endswith('.toml'): warnings.warn('gokart mypy plugin can be configured by pyproject.toml', stacklevel=1) return cls() config = cls._parse_toml(config_file) gokart_plugin_config = config.get('tool', {}).get('gokart-mypy', {}) disallow_missing_parameters = gokart_plugin_config.get(PluginOptions.DISALLOW_MISSING_PARAMETERS.value, False) if not isinstance(disallow_missing_parameters, bool): raise ValueError(f'{PluginOptions.DISALLOW_MISSING_PARAMETERS.value} must be a boolean value') return cls(disallow_missing_parameters=disallow_missing_parameters) class TaskOnKartPlugin(Plugin): def __init__(self, options: Options) -> None: super().__init__(options) if options.config_file is not None: self._options = TaskOnKartPluginOptions.parse_config_file(options.config_file) else: self._options = TaskOnKartPluginOptions() def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None: # The following gathers attributes from gokart.TaskOnKart such as `workspace_directory` # the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`. 
# # NOTE: `gokart.task.luigi.Task` condition is required for the release of luigi versions without py.typed if fullname in {'gokart.task.luigi.Task', 'luigi.task.Task'}: return self._task_on_kart_class_maker_callback sym = self.lookup_fully_qualified(fullname) if sym and isinstance(sym.node, TypeInfo): if any(base.fullname == 'gokart.task.TaskOnKart' for base in sym.node.mro): return self._task_on_kart_class_maker_callback return None def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None: """Adjust the return type of the `Parameters` function.""" if PARAMETER_FULLNAME_MATCHER.match(fullname): return self._task_on_kart_parameter_field_callback return None def _task_on_kart_class_maker_callback(self, ctx: ClassDefContext) -> None: transformer = TaskOnKartTransformer(ctx.cls, ctx.reason, ctx.api, self._options) transformer.transform() def _task_on_kart_parameter_field_callback(self, ctx: FunctionContext) -> Type: """Extract the type of the `default` argument from the Field function, and use it as the return type. In particular: * Retrieve the type of the argument which is specified, and use it as return type for the function. * If no default argument is specified, return AnyType with unannotated type instead of parameter types like `luigi.Parameter()` This makes mypy avoid conflict between the type annotation and the parameter type. e.g. ```python foo: int = luigi.IntParameter() ``` """ try: default_idx = ctx.callee_arg_names.index('default') # if no `default` argument is found, return AnyType with unannotated type. except ValueError: return AnyType(TypeOfAny.unannotated) default_args = ctx.args[default_idx] if default_args: default_type = ctx.arg_types[0][0] default_arg = default_args[0] # Fallback to default Any type if the field is required if not isinstance(default_arg, EllipsisExpr): return default_type # NOTE: This is a workaround to avoid the error between type annotation and parameter type. 
# As the following code snippet, the type of `foo` is `int` but the assigned value is `luigi.IntParameter()`. # foo: int = luigi.IntParameter() # TODO: infer mypy type from the parameter type. return AnyType(TypeOfAny.unannotated) class TaskOnKartAttribute: def __init__( self, name: str, has_default: bool, line: int, column: int, type: Type | None, info: TypeInfo, api: SemanticAnalyzerPluginInterface, options: TaskOnKartPluginOptions, ) -> None: self.name = name self.has_default = has_default self.line = line self.column = column self.type = type # Type as __init__ argument self.info = info self._api = api self._options = options def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__',]) -> Argument: if of == '__init__': arg_kind = self._get_arg_kind_by_options() return Argument( variable=self.to_var(current_info), type_annotation=self.expand_type(current_info), initializer=EllipsisExpr() if self.has_default else None, # Only used by stubgen kind=arg_kind, ) def expand_type(self, current_info: TypeInfo) -> Type | None: if self.type is not None and self.info.self_type is not None: # In general, it is not safe to call `expand_type()` during semantic analysis, # however this plugin is called very late, so all types should be fully ready. # Also, it is tricky to avoid eager expansion of Self types here (e.g. because # we serialize attributes). 
with state.strict_optional_set(self._api.options.strict_optional): return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)}) return self.type def to_var(self, current_info: TypeInfo) -> Var: return Var(self.name, self.expand_type(current_info)) def serialize(self) -> JsonDict: assert self.type return { 'name': self.name, 'has_default': self.has_default, 'line': self.line, 'column': self.column, 'type': self.type.serialize(), } @classmethod def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface, options: TaskOnKartPluginOptions) -> TaskOnKartAttribute: data = data.copy() typ = deserialize_and_fixup_type(data.pop('type'), api) return cls(type=typ, info=info, **data, api=api, options=options) def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None: """Expands type vars in the context of a subtype when an attribute is inherited from a generic super type.""" if self.type is not None: with state.strict_optional_set(self._api.options.strict_optional): self.type = map_type_from_supertype(self.type, sub_type, self.info) def _get_arg_kind_by_options(self) -> Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: """Set the argument kind based on the options. if `disallow_missing_parameters` is True, the argument kind is `ARG_NAMED` when the attribute has no default value. This means the that all the parameters are passed to the constructor as keyword-only arguments. Returns: Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind. 
""" if not self._options.disallow_missing_parameters: return ARG_NAMED_OPT if self.has_default: return ARG_NAMED_OPT # required parameter return ARG_NAMED class TaskOnKartTransformer: """Implement the behavior of gokart.TaskOnKart.""" def __init__( self, cls: ClassDef, reason: Expression | Statement, api: SemanticAnalyzerPluginInterface, options: TaskOnKartPluginOptions, ) -> None: self._cls = cls self._reason = reason self._api = api self._options = options def transform(self) -> bool: """Apply all the necessary transformations to the underlying gokart.TaskOnKart""" info = self._cls.info attributes = self.collect_attributes() if attributes is None: # Some definitions are not ready. We need another pass. return False for attr in attributes: if attr.type is None: return False # If there are no attributes, it may be that the semantic analyzer has not # processed them yet. In order to work around this, we can simply skip generating # __init__ if there are no attributes, because if the user truly did not define any, # then the object default __init__ with an empty signature will be present anyway. 
    def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
        """Collect all attributes declared in the task and its parents.

        All assignments of the form

            a: SomeType
            b: SomeOtherType = ...

        are collected.

        Return None if some base class hasn't been processed
        yet and thus we'll need to ask for another pass.
        """
        cls = self._cls

        # First, collect attributes belonging to any class in the MRO, ignoring duplicates.
        #
        # We iterate through the MRO in reverse because attrs defined in the parent must appear
        # earlier in the attributes list than attrs defined in the child.
        #
        # However, we also want attributes defined in the subtype to override ones defined
        # in the parent. We can implement this via a dict without disrupting the attr order
        # because dicts preserve insertion order in Python 3.7+.
        found_attrs: dict[str, TaskOnKartAttribute] = {}
        for info in reversed(cls.info.mro[1:-1]):
            if METADATA_TAG not in info.metadata:
                continue

            # Each class depends on the set of attributes in its task_on_kart ancestors.
            self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname))

            for data in info.metadata[METADATA_TAG]['attributes']:
                name: str = data['name']

                attr = TaskOnKartAttribute.deserialize(info, data, self._api, self._options)
                # TODO: We shouldn't be performing type operations during the main
                # semantic analysis pass, since some TypeInfo attributes might
                # still be in flux. This should be performed in a later phase.
                attr.expand_typevar_from_subtype(cls.info)
                found_attrs[name] = attr

                sym_node = cls.info.names.get(name)
                if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
                    self._api.fail(
                        'TaskOnKart attribute may only be overridden by another attribute',
                        sym_node.node,
                    )

        # Second, collect attributes belonging to the current class.
        current_attr_names: set[str] = set()
        for stmt in self._get_assignment_statements_from_block(cls.defs):
            # Only assignments whose right-hand side is a luigi/gokart Parameter()
            # call become task attributes; everything else is left to mypy.
            if not is_parameter_call(stmt.rvalue):
                continue

            # a: int, b: str = 1, 'foo' is not supported syntax so we
            # don't have to worry about it.
            lhs = stmt.lvalues[0]
            if not isinstance(lhs, NameExpr):
                continue

            sym = cls.info.names.get(lhs.name)
            if sym is None:
                # There was probably a semantic analysis error.
                continue

            node = sym.node
            assert not isinstance(node, PlaceholderNode)
            assert isinstance(node, Var)

            has_parameter_call, parameter_args = self._collect_parameter_args(stmt.rvalue)
            has_default = False
            # Ensure that something like x: int = field() is rejected
            # after an attribute with a default.
            if has_parameter_call:
                has_default = 'default' in parameter_args

            # All other assignments are already type checked.
            elif not isinstance(stmt.rvalue, TempNode):
                has_default = True

            if not has_default:
                # Make all non-default task_on_kart attributes implicit because they are de-facto
                # set on self in the generated __init__(), not in the class body. On the other
                # hand, we don't know how custom task_on_kart transforms initialize attributes,
                # so we don't treat them as implicit. This is required to support descriptors
                # (https://github.com/python/mypy/issues/14868).
                sym.implicit = True

            current_attr_names.add(lhs.name)
            with state.strict_optional_set(self._api.options.strict_optional):
                init_type = sym.type

            # infer Parameter type
            if init_type is None:
                init_type = self._infer_type_from_parameters(stmt.rvalue)

            found_attrs[lhs.name] = TaskOnKartAttribute(
                name=lhs.name,
                has_default=has_default,
                line=stmt.line,
                column=stmt.column,
                type=init_type,
                info=cls.info,
                api=self._api,
                options=self._options,
            )

        return list(found_attrs.values())
""" parameter_name = _extract_parameter_name(parameter) if parameter_name is None: return None underlying_type: Type | None = None if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']: underlying_type = self._api.named_type('builtins.str', []) elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']: underlying_type = self._api.named_type('builtins.int', []) elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']: underlying_type = self._api.named_type('builtins.float', []) elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']: underlying_type = self._api.named_type('builtins.bool', []) elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']: underlying_type = self._api.named_type('datetime.date', []) elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']: underlying_type = self._api.named_type('datetime.datetime', []) elif parameter_name in ['luigi.parameter.TimeDeltaParameter']: underlying_type = self._api.named_type('datetime.timedelta', []) elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']: underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)]) elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']: underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)]) elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']: underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)]) elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']: underlying_type = 
self._api.named_type('pathlib.Path', []) elif parameter_name in ['gokart.parameter.TaskInstanceParameter']: underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)]) elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']: underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])]) elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']: underlying_type = self._api.named_type('builtins.bool', []) elif parameter_name in ['luigi.parameter.NumericalParameter']: underlying_type = self._get_type_from_args(parameter, 'var_type') elif parameter_name in ['luigi.parameter.ChoiceParameter']: underlying_type = self._get_type_from_args(parameter, 'var_type') elif parameter_name in ['luigi.parameter.ChoiceListParameter']: base_type = self._get_type_from_args(parameter, 'var_type') if base_type is not None: underlying_type = self._api.named_type('builtins.tuple', [base_type]) elif parameter_name in ['luigi.parameter.EnumParameter']: underlying_type = self._get_type_from_args(parameter, 'enum') elif parameter_name in ['luigi.parameter.EnumListParameter']: base_type = self._get_type_from_args(parameter, 'enum') if base_type is not None: underlying_type = self._api.named_type('builtins.tuple', [base_type]) if underlying_type is None: return None # When parameter has Optional, it can be none value. if 'Optional' in parameter_name: return UnionType([underlying_type, NoneType()]) return underlying_type def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Type | None: """ get type from parameter arguments. e.x) When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type. 
""" ok, args = self._collect_parameter_args(parameter) if not ok: return None if arg_key not in args: return None arg = args[arg_key] if not isinstance(arg, NameExpr): return None if not isinstance(arg.node, TypeInfo): return None return Instance(arg.node, []) def is_parameter_call(expr: Expression) -> bool: """Checks if the expression is a call to luigi.Parameter()""" parameter_name = _extract_parameter_name(expr) if parameter_name is None: return False return PARAMETER_FULLNAME_MATCHER.match(parameter_name) is not None def _extract_parameter_name(expr: Expression) -> str | None: """Extract name if the expression is a call to luigi.Parameter()""" if not isinstance(expr, CallExpr): return None callee = expr.callee if isinstance(callee, MemberExpr): type_info = callee.node if type_info is None and isinstance(callee.expr, NameExpr): return f'{callee.expr.name}.{callee.name}' elif isinstance(callee, NameExpr): type_info = callee.node else: return None if isinstance(type_info, TypeInfo): return type_info.fullname # Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1. # https://github.com/spotify/luigi/pull/3297 # With the following code, we can't assume correctly. 
object_storage_path_prefix = ['s3://', 'gs://']


class ObjectStorage:
    """Static helpers that dispatch storage operations to S3 or GCS by path scheme."""

    @staticmethod
    def if_object_storage_path(path: str) -> bool:
        """Return True when ``path`` points at a supported object storage (``s3://`` or ``gs://``)."""
        return any(path.startswith(prefix) for prefix in object_storage_path_prefix)

    @staticmethod
    def get_object_storage_target(path: str, format: Format) -> luigi.target.FileSystemTarget:
        """Build a luigi ``FileSystemTarget`` for an ``s3://`` or ``gs://`` path."""
        if path.startswith('s3://'):
            return luigi.contrib.s3.S3Target(path, client=S3Config().get_s3_client(), format=format)
        elif path.startswith('gs://'):
            return luigi.contrib.gcs.GCSTarget(path, client=GCSConfig().get_gcs_client(), format=format)
        else:
            # Bug fix: this was a bare `raise`, which outside an `except` block fails
            # with "RuntimeError: No active exception to re-raise".
            raise ValueError(f'Unsupported object storage path: {path}')

    @staticmethod
    def exists(path: str) -> bool:
        """Return whether the object at ``path`` exists in S3/GCS."""
        if path.startswith('s3://'):
            return cast(bool, S3Config().get_s3_client().exists(path))
        elif path.startswith('gs://'):
            return cast(bool, GCSConfig().get_gcs_client().exists(path))
        else:
            # Bug fix: bare `raise` replaced with an explicit error (see above).
            raise ValueError(f'Unsupported object storage path: {path}')

    @staticmethod
    def get_timestamp(path: str) -> datetime:
        """Return the last-modified timestamp of the object at ``path``."""
        if path.startswith('s3://'):
            return cast(datetime, S3Config().get_s3_client().get_key(path).last_modified)
        elif path.startswith('gs://'):
            # for gcs object
            # should PR to luigi
            # NOTE(review): the GCS API returns `updated` as an RFC3339 string; the
            # cast only silences the type checker — confirm callers accept a string here.
            bucket, obj = GCSConfig().get_gcs_client()._path_to_bucket_and_key(path)
            result = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute()
            return cast(datetime, result['updated'])
        else:
            raise ValueError(f'Unsupported object storage path: {path}')

    @staticmethod
    def get_zip_client(file_path: str, temporary_directory: str) -> ZipClient:
        """Return the ``ZipClient`` implementation matching the path scheme."""
        if file_path.startswith('s3://'):
            return S3ZipClient(file_path=file_path, temporary_directory=temporary_directory)
        elif file_path.startswith('gs://'):
            return GCSZipClient(file_path=file_path, temporary_directory=temporary_directory)
        else:
            raise ValueError(f'Unsupported object storage path: {file_path}')

    @staticmethod
    def is_buffered_reader(file: object) -> bool:
        # ReadableS3File is the one wrapped file object that is not a buffered reader.
        return not isinstance(file, luigi.contrib.s3.ReadableS3File)
    @staticmethod
    def _recursive(param_dict):
        # Rebuild a task instance from a serialized {'type': task_family, 'params': {...}} dict.
        params = param_dict['params']
        task_cls = param_dict and task_register.Register.get_task_cls(param_dict['type']) if False else task_register.Register.get_task_cls(param_dict['type'])
        # Parse each serialized value with the parameter object declared on the task
        # class, so nested TaskInstanceParameter values are reconstructed recursively
        # (their `parse` routes back through this method).
        for key, value in task_cls.get_params():
            if key in params:
                params[key] = value.parse(params[key])
        return task_cls(**params)
class ExplicitBoolParameter(luigi.BoolParameter):
    """A ``BoolParameter`` that requires an explicit value.

    ``__init__`` deliberately calls ``luigi.Parameter.__init__`` instead of
    ``luigi.BoolParameter.__init__`` — presumably to skip BoolParameter's
    flag-style command-line setup so a value must be spelled out explicitly
    (confirm against luigi's ``BoolParameter`` implementation).
    """

    def __init__(self, *args, **kwargs):
        # Intentionally bypass BoolParameter.__init__ (see class docstring).
        luigi.Parameter.__init__(self, *args, **kwargs)

    def _parser_kwargs(self, *args, **kwargs):  # type: ignore
        # Bug fix: keyword arguments were forwarded with `*kwargs`, which unpacks
        # only the dict *keys* as positional arguments. Forward them as keyword
        # arguments with `**kwargs`.
        return luigi.Parameter._parser_kwargs(*args, **kwargs)
@dataclass
class RequiredTaskOutput:
    """Reference to a dependency task's output: the task name and its output file path."""

    task_name: str
    output_path: str

    def serialize(self) -> dict[str, str]:
        """Render as the ``__gokart_*``-keyed dict recorded in output metadata."""
        return {
            '__gokart_task_name': self.task_name,
            '__gokart_output_path': self.output_path,
        }
def _try_get_slack_api(cmdline_args: list[str]) -> gokart.slack.SlackAPI | None:
    """Build a SlackAPI client when both a token and a channel are configured.

    The token is read from the environment variable named by
    ``SlackConfig.token_name``. Returns None (and logs) when notification
    is not fully configured.
    """
    with CmdlineParser.global_instance(cmdline_args):
        config = gokart.slack.SlackConfig()
        token = os.getenv(config.token_name, '')
        channel = config.channel
        to_user = config.to_user
        # Guard clause: both a token and a channel are required.
        if not (token and channel):
            logger.info('Slack notification is not activated.')
            return None
        logger.info('Slack notification is activated.')
        return gokart.slack.SlackAPI(token=token, channel=channel, to_user=to_user)
def _run_with_retcodes(argv):
    """run_with_retcodes equivalent that uses gokart's WorkerSchedulerFactory."""
    retcode_logger = logging.getLogger('luigi-interface')
    with luigi.cmdline_parser.CmdlineParser.global_instance(argv):
        retcodes = luigi.retcodes.retcode()
    worker = None
    try:
        # Run through luigi's internal entry point so gokart's custom
        # worker/scheduler factory is used instead of luigi's default one.
        worker = luigi.interface._run(argv, worker_scheduler_factory=WorkerSchedulerFactory()).worker
    except luigi.interface.PidLockAlreadyTakenExit:
        sys.exit(retcodes.already_running)
    except Exception:
        # Make sure logging is configured before reporting, then exit with the
        # configured 'unhandled_exception' code.
        env_params = luigi.interface.core()
        luigi.setup_logging.InterfaceLogging.setup(env_params)
        retcode_logger.exception('Uncaught exception in luigi')
        sys.exit(retcodes.unhandled_exception)
    with luigi.cmdline_parser.CmdlineParser.global_instance(argv):
        task_sets = luigi.execution_summary._summary_dict(worker)
        root_task = luigi.execution_summary._root_task(worker)
        non_empty_categories = {k: v for k, v in task_sets.items() if v}.keys()

        def has(status):
            # Guard against typos: status must be a known luigi summary category.
            assert status in luigi.execution_summary._ORDERED_STATUSES
            return status in non_empty_categories

        # Map each non-empty execution-summary category to its exit code; the
        # largest applicable code wins.
        codes_and_conds = (
            (retcodes.missing_data, has('still_pending_ext')),
            (retcodes.task_failed, has('failed')),
            (retcodes.already_running, has('run_by_other_worker')),
            (retcodes.scheduling_error, has('scheduling_error')),
            (retcodes.not_run, has('not_run')),
        )
        expected_ret_code = max(code * (1 if cond else 0) for code, cond in codes_and_conds)

        if expected_ret_code == 0 and root_task not in task_sets['completed'] and root_task not in task_sets['already_done']:
            # No error category fired, but the root task did not finish either:
            # report it as 'not_run'.
            sys.exit(retcodes.not_run)
        else:
            sys.exit(expected_ret_code)
class S3Config(luigi.Config):
    """Luigi config naming the environment variables that hold AWS credentials."""

    aws_access_key_id_name = luigi.Parameter(default='AWS_ACCESS_KEY_ID', description='AWS access key id environment variable.')
    aws_secret_access_key_name = luigi.Parameter(default='AWS_SECRET_ACCESS_KEY', description='AWS secret access key environment variable.')

    # Lazily created S3 client, cached after the first call (singleton-style).
    _client = None

    def get_s3_client(self) -> luigi.contrib.s3.S3Client:
        """Return the cached S3 client, creating it on first use."""
        client = self._client
        if client is None:
            client = self._get_s3_client()
            self._client = client
        return client

    def _get_s3_client(self) -> luigi.contrib.s3.S3Client:
        """Construct a fresh client from the configured environment variables."""
        return luigi.contrib.s3.S3Client(
            aws_access_key_id=os.environ.get(self.aws_access_key_id_name),
            aws_secret_access_key=os.environ.get(self.aws_secret_access_key_name),
        )
class FailureEvent(TypedDict):
    """Record of one failed task: its display name and the exception text."""

    task: str
    exception: str


class EventAggregator:
    """Aggregates luigi SUCCESS/FAILURE events so a run can be summarized afterwards (e.g. for Slack)."""

    def __init__(self) -> None:
        self._success_events: list[str] = []
        self._failure_events: list[FailureEvent] = []

    def set_handlers(self):
        """Register this aggregator's callbacks on luigi's global task event bus."""
        handlers = [(luigi.Event.SUCCESS, self._success), (luigi.Event.FAILURE, self._failure)]
        for event, handler in handlers:
            luigi.Task.event_handler(event)(handler)

    def get_summary(self) -> str:
        """Return a one-line count of successful and failed tasks."""
        return f'Success: {len(self._success_events)}; Failure: {len(self._failure_events)}'

    def get_event_list(self) -> str:
        """Return a human-readable listing of failed and successful tasks.

        The failure section comes first; a placeholder message is returned
        when no events were recorded at all.
        """
        sections = []
        if len(self._failure_events) != 0:
            failure_message = os.linesep.join([f'Task: {failure["task"]}; Exception: {failure["exception"]}' for failure in self._failure_events])
            sections.append('---- Failure Tasks ----' + os.linesep + failure_message)
        if len(self._success_events) != 0:
            success_message = os.linesep.join(self._success_events)
            sections.append('---- Success Tasks ----' + os.linesep + success_message)
        if not sections:
            return 'Tasks were not executed.'
        # BUGFIX: previously the two sections were concatenated with no
        # separator, gluing '---- Success Tasks ----' onto the last failure
        # line; join them with a line separator instead.
        return os.linesep.join(sections)

    def _success(self, task):
        self._success_events.append(self._task_to_str(task))

    def _failure(self, task, exception):
        failure: FailureEvent = {'task': self._task_to_str(task), 'exception': str(exception)}
        self._failure_events.append(failure)

    @staticmethod
    def _task_to_str(task: Any) -> str:
        # Identify a task by its class name plus its gokart unique id.
        return f'{type(task).__name__}:[{task.make_unique_id()}]'
    def send_snippet(self, comment, title, content):
        """Upload ``content`` as a file snippet to the configured channel.

        ``comment`` is posted as the initial message, prefixed with a user
        mention when ``to_user`` was configured. Any failure is logged as a
        warning and swallowed, so notification problems never abort a run.
        """
        try:
            request_body = dict(
                channels=self._channel_id,
                initial_comment=f'<{self._to_user}> {comment}' if self._to_user else comment,
                content=content,
                title=title,
            )
            # NOTE(review): 'files.upload' is a legacy Slack Web API method
            # (deprecated in favor of files.getUploadURLExternal /
            # files.completeUploadExternal) — confirm against the slack_sdk
            # version in use.
            response = self._client.api_call('files.upload', data=request_body)
            if not response['ok']:
                raise FileNotUploadedError(f'Error while uploading file. The error reason is "{response["error"]}".')
        except Exception as e:
            logger.warning(f'Failed to send slack notification: {e}')
'It is recommended to set false to this option when notification takes long time.', ) ================================================ FILE: gokart/target.py ================================================ from __future__ import annotations import hashlib import os import shutil from abc import abstractmethod from datetime import datetime from glob import glob from logging import getLogger from typing import Any, cast import luigi import numpy as np import pandas as pd from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock from gokart.file_processor import FileProcessor, make_file_processor from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient from gokart.object_storage import ObjectStorage from gokart.required_task_output import RequiredTaskOutput from gokart.utils import FlattenableItems from gokart.zip_client_util import make_zip_client logger = getLogger(__name__) class TargetOnKart(luigi.Target): def exists(self) -> bool: return self._exists() def load(self) -> Any: return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())() def dump( self, obj: Any, lock_at_dump: bool = True, task_params: dict[str, str] | None = None, custom_labels: dict[str, str] | None = None, required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None, ) -> None: if lock_at_dump: wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)( obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs, ) else: self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs) def remove(self) -> None: if self.exists(): wrap_remove_with_lock(self._remove, task_lock_params=self._get_task_lock_params())() def 
class SingleFileTarget(TargetOnKart):
    """TargetOnKart backed by a single file, (de)serialized through a FileProcessor."""

    def __init__(
        self,
        target: luigi.target.FileSystemTarget,
        processor: FileProcessor,
        task_lock_params: TaskLockParams,
    ) -> None:
        self._target = target
        self._processor = processor
        self._task_lock_params = task_lock_params

    def _exists(self) -> bool:
        return cast(bool, self._target.exists())

    def _get_task_lock_params(self) -> TaskLockParams:
        return self._task_lock_params

    def _load(self) -> Any:
        # Deserialize the file contents via the configured processor.
        with self._target.open('r') as f:
            return self._processor.load(f)

    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        with self._target.open('w') as f:
            self._processor.dump(obj, f)
        # Task-state labels only apply to GCS outputs.
        if self.path().startswith('gs://'):
            GCSObjectMetadataClient.add_task_state_labels(
                path=self.path(), task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
            )

    def _remove(self) -> None:
        self._target.remove()

    def _last_modification_time(self) -> datetime:
        return _get_last_modification_time(self._target.path)

    def _path(self) -> str:
        return self._target.path
    def _load(self) -> Any:
        """Unpack the archived model directory and load the model with the load function.

        If no load function was supplied at construction time, it is restored
        from 'load_function.pkl' inside the unpacked archive (and cached on
        ``self`` for subsequent calls).
        """
        self._zip_client.unpack_archive()
        # Lazily restore the pickled load function on first use.
        self._load_function = self._load_function or make_target(self._load_function_path()).load()
        model = self._load_function(self._model_path())
        # Clean up the unpacked working directory once the model is in memory.
        self._remove_temporary_directory()
        return model
'data_0.pkl')) return split_size = df.values.nbytes // self.max_byte + 1 logger.info(f'saving a large pdDataFrame with split_size={split_size}') for i, idx in list(enumerate(np.array_split(range(df.shape[0]), split_size))): df.iloc[idx[0] : idx[-1] + 1].to_pickle(os.path.join(dir_path, f'data_{i}.pkl')) @staticmethod def load(file_path: str) -> pd.DataFrame: dir_path = os.path.dirname(file_path) return pd.concat([pd.read_pickle(file_path) for file_path in glob(os.path.join(dir_path, 'data_*.pkl'))]) def _make_file_system_target(file_path: str, processor: FileProcessor | None = None, store_index_in_feather: bool = True) -> luigi.target.FileSystemTarget: processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) if ObjectStorage.if_object_storage_path(file_path): return ObjectStorage.get_object_storage_target(file_path, processor.format()) return luigi.LocalTarget(file_path, format=processor.format()) def _make_file_path(original_path: str, unique_id: str | None = None) -> str: if unique_id is not None: [base, extension] = os.path.splitext(original_path) return base + '_' + unique_id + extension return original_path def _get_last_modification_time(path: str) -> datetime: if ObjectStorage.if_object_storage_path(path): if ObjectStorage.exists(path): return ObjectStorage.get_timestamp(path) raise FileNotFoundError(f'No such file or directory: {path}') return datetime.fromtimestamp(os.path.getmtime(path)) def make_target( file_path: str, unique_id: str | None = None, processor: FileProcessor | None = None, task_lock_params: TaskLockParams | None = None, store_index_in_feather: bool = True, ) -> TargetOnKart: _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id) file_path = _make_file_path(file_path, unique_id) processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather) file_system_target = 
def make_model_target(
    file_path: str,
    temporary_directory: str,
    save_function: Any,
    load_function: Any,
    unique_id: str | None = None,
    task_lock_params: TaskLockParams | None = None,
) -> TargetOnKart:
    """Create a ModelTarget that archives a model directory into ``file_path``.

    ``save_function``/``load_function`` define how the model is written to and
    read from the temporary directory before/after zipping. When
    ``task_lock_params`` is not given, default lock parameters are built from
    the file path and unique id.
    """
    _task_lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
    file_path = _make_file_path(file_path, unique_id)
    # Derive a per-target temp dir from the md5 of the final path so that
    # different targets never share a working directory.
    temporary_directory = os.path.join(temporary_directory, hashlib.md5(file_path.encode()).hexdigest())
    return ModelTarget(
        file_path=file_path,
        temporary_directory=temporary_directory,
        save_function=save_function,
        load_function=load_function,
        task_lock_params=_task_lock_params,
    )
# NOTE: inherited from AssertionError for backward compatibility (Formerly, Gokart raises that exception when a task dumps an empty DataFrame).
class EmptyDumpError(AssertionError):
    """Raised when a task tries to dump an empty DataFrame while ``fail_on_empty_dump`` is set to True."""
Please use a path starts with s3:// when you use s3.', significant=False ) local_temporary_directory: luigi.Parameter[str] = luigi.Parameter( default='./resources/tmp/', description='A directory to save temporary files.', significant=False ) rerun: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, this task will run even if all output files exist.', significant=False ) strict_check: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, this task will not run only if all input and output files exist.', significant=False ) modification_time_check: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, this task will not run only if all input and output files exist,' ' and all input files are modified before output file are modified.', significant=False, ) serialized_task_definition_check: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, even if all outputs are present,this task will be executed if any changes have been made to the code.', significant=False, ) delete_unnecessary_output_files: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If this is true, delete unnecessary output files.', significant=False ) significant: luigi.BoolParameter = luigi.BoolParameter( default=True, description='If this is false, this task is not treated as a part of dependent tasks for the unique id.', significant=False ) fix_random_seed_methods: luigi.Parameter[tuple[str, ...]] = luigi.ListParameter( default=('random.seed', 'numpy.random.seed'), description='Fix random seed method list.', significant=False ) FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER = -42497368 fix_random_seed_value: luigi.Parameter[int] = luigi.IntParameter( default=FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER, description='Fix random seed method value.', significant=False ) # FIXME: should fix with OptionalIntParameter after newer luigi 
(https://github.com/spotify/luigi/pull/3079) will be released redis_host: luigi.Parameter[str | None] = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False) redis_port: luigi.OptionalIntParameter = luigi.OptionalIntParameter( default=None, # type: ignore description='Task lock check is deactivated, when None.', significant=False, ) redis_timeout: luigi.IntParameter = luigi.IntParameter( default=180, description='Redis lock will be released after `redis_timeout` seconds', significant=False ) fail_on_empty_dump: luigi.Parameter[bool] = ExplicitBoolParameter(default=False, description='Fail when task dumps empty DF', significant=False) store_index_in_feather: luigi.Parameter[bool] = ExplicitBoolParameter( default=True, description='Wether to store index when using feather as a output object.', significant=False ) cache_unique_id: luigi.Parameter[bool] = ExplicitBoolParameter(default=True, description='Cache unique id during runtime', significant=False) should_dump_supplementary_log_files: luigi.Parameter[bool] = ExplicitBoolParameter( default=True, description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \ Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.', significant=False, ) complete_check_at_run: luigi.Parameter[bool] = ExplicitBoolParameter( default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False ) should_lock_run: luigi.Parameter[bool] = ExplicitBoolParameter( default=False, significant=False, description='Whether to use redis lock or not at task run.' 
    def __init__(self, *args, **kwargs):
        """Initialize the task, applying [TaskOnKart] config defaults and wrapping run().

        After the normal luigi initialization, run() may be wrapped twice:
        once to skip execution when outputs already exist
        (``complete_check_at_run``) and once with a redis run lock
        (``should_lock_run``).
        """
        # Fill kwargs with values from the [TaskOnKart] config section before
        # luigi's own parameter resolution sees them.
        self._add_configuration(kwargs, 'TaskOnKart')
        # 'This parameter is dumped into "workspace_directory/log/task_log/" when this task finishes with success.'
        self.task_log = dict()
        self.task_unique_id = None
        super().__init__(*args, **kwargs)
        # complete() flips this back to False after removing outputs, so a
        # rerun is triggered only once.
        self._rerun_state = self.rerun
        self._lock_at_dump = True
        # Cache to_str_params to avoid slow task creation in a deep task tree.
        # For example, gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) results in O(n^2) calls to to_str_params.
        # However, @lru_cache cannot be used as a decorator because luigi.Task employs metaclass tricks.
        self.to_str_params = functools.lru_cache(maxsize=None)(self.to_str_params)  # type: ignore[method-assign]
        if self.complete_check_at_run:
            self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)  # type: ignore
        if self.should_lock_run:
            # Dump-time locking is redundant when the whole run() is locked.
            self._lock_at_dump = False
            assert self.redis_host is not None, 'redis_host must be set when should_lock_run is True.'
            assert self.redis_port is not None, 'redis_port must be set when should_lock_run is True.'
            task_lock_params = make_task_lock_params_for_run(task_self=self)
            self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params)  # type: ignore
    def make_task_instance_dictionary(self) -> dict[str, TaskOnKart[Any]]:
        """Collect instance attributes that are TaskOnKart objects (or tuples of them), keyed by attribute name."""
        return {key: var for key, var in vars(self).items() if self.is_task_on_kart(var)}

    @staticmethod
    def is_task_on_kart(value):
        """Return True when `value` is a TaskOnKart or a non-empty tuple consisting only of TaskOnKart objects."""
        return isinstance(value, TaskOnKart) or (isinstance(value, tuple) and bool(value) and all([isinstance(v, TaskOnKart) for v in value]))

    @classmethod
    def _add_configuration(cls, kwargs, section):
        # Fill missing constructor kwargs from the given luigi config section,
        # parsing raw strings with the corresponding Parameter object.
        config = luigi.configuration.get_config()
        class_variables = dict(TaskOnKart.__dict__)
        class_variables.update(dict(cls.__dict__))
        if section not in config:
            return
        for key, value in dict(config[section]).items():
            if key not in kwargs and key in class_variables:
                kwargs[key] = class_variables[key].parse(value)

    def complete(self) -> bool:
        """
        Check completion. On the first call with `rerun=True`, existing outputs are
        removed and False is returned so the task runs again. Optionally also checks
        requirements (`strict_check`) and input/output modification times
        (`modification_time_check`).
        """
        if self._rerun_state:
            for target in flatten(self.output()):
                target.remove()
            self._rerun_state = False
            return False

        is_completed = all([t.exists() for t in flatten(self.output())])

        if self.strict_check or self.modification_time_check:
            requirements = flatten(self.requires())
            inputs = flatten(self.input())
            is_completed = is_completed and all([task.complete() for task in requirements]) and all([i.exists() for i in inputs])

        if not self.modification_time_check or not is_completed or not self.input():
            return is_completed

        return self._check_modification_time()

    def _check_modification_time(self) -> bool:
        # Paths shared between input and output are excluded from the comparison.
        common_path = set(t.path() for t in flatten(self.input())) & set(t.path() for t in flatten(self.output()))
        input_tasks = [t for t in flatten(self.input()) if t.path() not in common_path]
        output_tasks = [t for t in flatten(self.output()) if t.path() not in common_path]

        input_modification_time = max([target.last_modification_time() for target in input_tasks]) if input_tasks else None
        output_modification_time = min([target.last_modification_time() for target in output_tasks]) if output_tasks else None

        if input_modification_time is None or output_modification_time is None:
            return True

        # "=" must be required in the following statements, because some tasks use input targets as output targets.
        return input_modification_time <= output_modification_time

    def clone(self, cls=None, **kwargs):
        """Clone this task into `cls` (default: same class), copying parameter values except the rerun/check flags."""
        _SPECIAL_PARAMS = {'rerun', 'strict_check', 'modification_time_check'}
        if cls is None:
            cls = self.__class__

        new_k = {}
        for param_name, _ in cls.get_params():
            if param_name in kwargs:
                new_k[param_name] = kwargs[param_name]
            elif hasattr(self, param_name) and (param_name not in _SPECIAL_PARAMS):
                new_k[param_name] = getattr(self, param_name)
        return cls(**new_k)

    def make_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, processor: FileProcessor | None = None) -> TargetOnKart:
        """Create a single-file target under `workspace_directory`, optionally suffixed with this task's unique id."""
        formatted_relative_file_path = (
            relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl')
        )
        file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
        unique_id = self.make_unique_id() if use_unique_id else None

        # Auto-select processor based on type parameter if not provided
        if processor is None and relative_file_path is not None:
            processor = self._create_processor_for_dataframe_type(file_path)

        task_lock_params = make_task_lock_params(
            file_path=file_path,
            unique_id=unique_id,
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_timeout=self.redis_timeout,
            raise_task_lock_exception_on_collision=False,
        )
        return gokart.target.make_target(
            file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
        )

    def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor:
        # Pick a pandas/polars processor based on this task's TaskOnKart[T] type argument.
        df_type = get_dataframe_type_from_task(self)
        return make_file_processor(file_path, dataframe_type=df_type, store_index_in_feather=self.store_index_in_feather)

    def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte: int = int(2**26)) -> TargetOnKart:
        """Create a zip target whose DataFrame content is split into chunks of at most `max_byte` bytes on save."""
        formatted_relative_file_path = (
            relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')
        )
        file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
        unique_id = self.make_unique_id() if use_unique_id else None
        task_lock_params = make_task_lock_params(
            file_path=file_path,
            unique_id=unique_id,
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_timeout=self.redis_timeout,
            raise_task_lock_exception_on_collision=False,
        )
        return gokart.target.make_model_target(
            file_path=file_path,
            temporary_directory=self.local_temporary_directory,
            unique_id=unique_id,
            save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save,
            load_function=gokart.target.LargeDataFrameProcessor.load,
            task_lock_params=task_lock_params,
        )

    def make_model_target(
        self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True
    ) -> TargetOnKart:
        """
        Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on.

        :param relative_file_path: A file path to save.
        :param save_function: A function to save a model. This takes a model object and a file path.
        :param load_function: A function to load a model. This takes a file path and returns a model object.
        :param use_unique_id: If this is true, add an unique id to a file base name.
        """
        file_path = os.path.join(self.workspace_directory, relative_file_path)
        assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.'
        unique_id = self.make_unique_id() if use_unique_id else None
        task_lock_params = make_task_lock_params(
            file_path=file_path,
            unique_id=unique_id,
            redis_host=self.redis_host,
            redis_port=self.redis_port,
            redis_timeout=self.redis_timeout,
            raise_task_lock_exception_on_collision=False,
        )
        return gokart.target.make_model_target(
            file_path=file_path,
            temporary_directory=self.local_temporary_directory,
            unique_id=unique_id,
            save_function=save_function,
            load_function=load_function,
            task_lock_params=task_lock_params,
        )

    @overload
    def load(self, target: None | str | TargetOnKart = None) -> Any: ...

    @overload
    def load(self, target: TaskOnKart[K]) -> K: ...

    @overload
    def load(self, target: list[TaskOnKart[K]]) -> list[K]: ...

    def load(self, target: None | str | TargetOnKart | TaskOnKart[K] | list[TaskOnKart[K]] = None) -> Any:
        """Load data from the resolved input target(s), recursively preserving list/tuple/dict structure."""

        def _load(targets):
            if isinstance(targets, list) or isinstance(targets, tuple):
                return [_load(t) for t in targets]
            if isinstance(targets, dict):
                return {k: _load(t) for k, t in targets.items()}
            return targets.load()

        return _load(self._get_input_targets(target))

    @overload
    def load_generator(self, target: None | str | TargetOnKart = None) -> Generator[Any, None, None]: ...

    @overload
    def load_generator(self, target: list[TaskOnKart[K]]) -> Generator[K, None, None]: ...

    def load_generator(self, target: None | str | TargetOnKart | list[TaskOnKart[K]] = None) -> Generator[Any, None, None]:
        """Lazily yield loaded data from the resolved input target(s)."""

        def _load(targets):
            if isinstance(targets, list) or isinstance(targets, tuple):
                for t in targets:
                    yield from _load(t)
            elif isinstance(targets, dict):
                for k, t in targets.items():
                    # NOTE(review): `yield from {k: _load(t)}` iterates the dict and
                    # therefore yields the key `k`, not the loaded value — looks
                    # suspicious; confirm intended behavior for dict inputs.
                    yield from {k: _load(t)}
            else:
                yield targets.load()

        return cast(Generator[Any, None, None], _load(self._get_input_targets(target)))

    @overload
    def dump(self, obj: T, target: None = None, custom_labels: dict[Any, Any] | None = None) -> None: ...

    @overload
    def dump(self, obj: Any, target: str | TargetOnKart, custom_labels: dict[Any, Any] | None = None) -> None: ...
def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels: dict[str, Any] | None = None) -> None: PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace) if self.fail_on_empty_dump: if isinstance(obj, pd.DataFrame) and obj.empty: raise EmptyDumpError() required_task_outputs = map_flattenable_items( lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()), self.requires(), ) self._get_output_target(target).dump( obj, lock_at_dump=self._lock_at_dump, task_params=super().to_str_params(only_significant=True, only_public=True), custom_labels=custom_labels, required_task_outputs=required_task_outputs, ) @staticmethod def get_code(target_class: Any) -> set[str]: def has_sourcecode(obj): return inspect.ismethod(obj) or inspect.isfunction(obj) or inspect.isframe(obj) or inspect.iscode(obj) return {inspect.getsource(t) for _, t in inspect.getmembers(target_class, has_sourcecode)} def get_own_code(self): gokart_codes = self.get_code(TaskOnKart) own_codes = self.get_code(self) return ''.join(sorted(list(own_codes - gokart_codes))) def make_unique_id(self) -> str: unique_id = self.task_unique_id or self._make_hash_id() if self.cache_unique_id: self.task_unique_id = unique_id return unique_id def _make_hash_id(self) -> str: def _to_str_params(task): if isinstance(task, TaskOnKart): return str(task.make_unique_id()) if task.significant else None if not isinstance(task, luigi.Task): raise ValueError(f'Task.requires method returns {type(task)}. 
You should return luigi.Task.') return task.to_str_params(only_significant=True) dependencies = [_to_str_params(task) for task in flatten(self.requires())] dependencies = [d for d in dependencies if d is not None] dependencies.append(self.to_str_params(only_significant=True)) dependencies.append(self.__class__.__name__) if self.serialized_task_definition_check: dependencies.append(self.get_own_code()) return hashlib.md5(str(dependencies).encode()).hexdigest() def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart[Any] | list[TaskOnKart[Any]]) -> FlattenableItems[TargetOnKart]: if target is None: return self.input() if isinstance(target, str): input = self.input() assert isinstance(input, dict), f'input must be dict[str, TargetOnKart], but {type(input)} is passed.' result: FlattenableItems[TargetOnKart] = input[target] return result if isinstance(target, Iterable): return [self._get_input_targets(t) for t in target] if isinstance(target, TaskOnKart): requires_unique_ids = [task.make_unique_id() for task in flatten(self.requires())] assert target.make_unique_id() in requires_unique_ids, f'{target} should be in requires method' return target.output() return target def _get_output_target(self, target: None | str | TargetOnKart) -> TargetOnKart: if target is None: output = self.output() assert isinstance(output, TargetOnKart), f'output must be TargetOnKart, but {type(output)} is passed.' return output if isinstance(target, str): output = self.output() assert isinstance(output, dict), f'output must be dict[str, TargetOnKart], but {type(output)} is passed.' result = output[target] assert isinstance(result, TargetOnKart), f'output must be dict[str, TargetOnKart], but {type(output)} is passed.' 
return result return target def get_info(self, only_significant=False): params_str = {} params = dict(self.get_params()) for param_name, param_value in self.param_kwargs.items(): if (not only_significant) or params[param_name].significant: if isinstance(params[param_name], gokart.TaskInstanceParameter): params_str[param_name] = type(param_value).__name__ + '-' + param_value.make_unique_id() else: params_str[param_name] = params[param_name].serialize(param_value) return params_str def _get_task_log_target(self): return self.make_target(f'log/task_log/{type(self).__name__}.pkl') def get_task_log(self) -> dict[str, Any]: target = self._get_task_log_target() if self.task_log: return self.task_log if target.exists(): return cast(dict[Any, Any], self.load(target)) return dict() @luigi.Task.event_handler(luigi.Event.SUCCESS) def _dump_task_log(self): self.task_log['file_path'] = [target.path() for target in flatten(self.output())] if self.should_dump_supplementary_log_files: self.dump(self.task_log, self._get_task_log_target()) def _get_task_params_target(self): return self.make_target(f'log/task_params/{type(self).__name__}.pkl') def get_task_params(self) -> dict[str, Any]: target = self._get_task_log_target() if target.exists(): return cast(dict[Any, Any], self.load(target)) return dict() @luigi.Task.event_handler(luigi.Event.START) def _set_random_seed(self): if self.should_dump_supplementary_log_files: random_seed = self._get_random_seed() seed_methods = self.try_set_seed(list(self.fix_random_seed_methods), random_seed) self.dump({'seed': random_seed, 'seed_methods': seed_methods}, self._get_random_seeds_target()) def _get_random_seeds_target(self): return self.make_target(f'log/random_seed/{type(self).__name__}.pkl') @staticmethod def try_set_seed(methods: list[str], random_seed: int) -> list[str]: success_methods: list[str] = [] for method_name in methods: try: parts = method_name.split('.') m: Any = import_module(parts[0]) for x in parts[1:]: m = getattr(m, x) 
m(random_seed) success_methods.append(method_name) except ModuleNotFoundError: pass except AttributeError: pass return success_methods def _get_random_seed(self): if self.fix_random_seed_value and (not self.fix_random_seed_value == self.FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER): return self.fix_random_seed_value return int(self.make_unique_id(), 16) % (2**32 - 1) # maximum numpy.random.seed @luigi.Task.event_handler(luigi.Event.START) def _dump_task_params(self): if self.should_dump_supplementary_log_files: self.dump(self.to_str_params(only_significant=True), self._get_task_params_target()) def _get_processing_time_target(self): return self.make_target(f'log/processing_time/{type(self).__name__}.pkl') def get_processing_time(self) -> str: target = self._get_processing_time_target() if target.exists(): return cast(str, self.load(target)) return 'unknown' @luigi.Task.event_handler(luigi.Event.PROCESSING_TIME) def _dump_processing_time(self, processing_time): if self.should_dump_supplementary_log_files: self.dump(processing_time, self._get_processing_time_target()) @classmethod def restore(cls, unique_id): params = TaskOnKart().make_target(f'log/task_params/{cls.__name__}_{unique_id}.pkl', use_unique_id=False).load() return cls.from_str_params(params) @luigi.Task.event_handler(luigi.Event.FAILURE) def _log_unique_id(self, exception): logger.info(f'FAILURE:\n task name={type(self).__name__}\n unique id={self.make_unique_id()}') @luigi.Task.event_handler(luigi.Event.START) def _dump_module_versions(self): if self.should_dump_supplementary_log_files: self.dump(self._get_module_versions(), self._get_module_versions_target()) def _get_module_versions_target(self): return self.make_target(f'log/module_versions/{type(self).__name__}.txt') def _get_module_versions(self) -> str: module_versions = [] for x in set([x.split('.')[0] for x in globals().keys() if isinstance(x, types.ModuleType) and '_' not in x]): module = import_module(x) if '__version__' in dir(module): if 
isinstance(module.__version__, str): version = module.__version__.split(' ')[0] else: version = '.'.join([str(v) for v in module.__version__]) module_versions.append(f'{x}=={version}') return '\n'.join(module_versions) def __repr__(self): """ Build a task representation like `MyTask[aca2f28555dadd0f1e3dee3d4b973651](param1=1.5, param2='5', data_task=DataTask(c1f5d06aa580c5761c55bd83b18b0b4e))` """ return self._get_task_string() def __str__(self): """ Build a human-readable task representation like `MyTask[aca2f28555dadd0f1e3dee3d4b973651](param1=1.5, param2='5', data_task=DataTask(c1f5d06aa580c5761c55bd83b18b0b4e))` This includes only public parameters """ return self._get_task_string(only_public=True) def _get_task_string(self, only_public=False): """ Convert a task representation like `MyTask(param1=1.5, param2='5', data_task=DataTask(id=35tyi))` """ params = self.get_params() param_values = self.get_param_values(params, [], self.param_kwargs) # Build up task id repr_parts = [] param_objs = dict(params) for param_name, param_value in param_values: param_obj = param_objs[param_name] if param_obj.significant and ((not only_public) or param_obj.visibility == ParameterVisibility.PUBLIC): repr_parts.append(f'{param_name}={self._make_representation(param_obj, param_value)}') task_str = f'{self.get_task_family()}[{self.make_unique_id()}]({", ".join(repr_parts)})' return task_str def _make_representation(self, param_obj: luigi.Parameter, param_value: Any) -> str: if isinstance(param_obj, TaskInstanceParameter): return f'{param_value.get_task_family()}({param_value.make_unique_id()})' if isinstance(param_obj, ListTaskInstanceParameter): return f'[{", ".join(f"{v.get_task_family()}({v.make_unique_id()})" for v in param_value)}]' return str(param_obj.serialize(param_value)) ================================================ FILE: gokart/task_complete_check.py ================================================ from __future__ import annotations import functools from 
from collections.abc import Callable
from logging import getLogger
from typing import Any

logger = getLogger(__name__)


def task_complete_check_wrapper(run_func: Callable[..., Any], complete_check_func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap `run_func` so it is skipped (with a warning) when `complete_check_func()` reports completion."""

    @functools.wraps(run_func)
    def wrapper(*args, **kwargs):
        if complete_check_func():
            logger.warning(f'{run_func.__name__} is skipped because the task is already completed.')
            return
        return run_func(*args, **kwargs)

    return wrapper


# ================================================ FILE: gokart/testing/__init__.py ================================================
__all__ = [
    'test_run',
    'try_to_run_test_for_empty_data_frame',
    'assert_frame_contents_equal',
]

from gokart.testing.check_if_run_with_empty_data_frame import test_run, try_to_run_test_for_empty_data_frame
from gokart.testing.pandas_assert import assert_frame_contents_equal

# ================================================ FILE: gokart/testing/check_if_run_with_empty_data_frame.py ================================================
from __future__ import annotations

import logging
import sys
from typing import Any

import luigi
from luigi.cmdline_parser import CmdlineParser

import gokart
from gokart.utils import flatten

test_logger = logging.getLogger(__name__)
test_logger.addHandler(logging.StreamHandler())
test_logger.setLevel(logging.INFO)


class test_run(gokart.TaskOnKart[Any]):
    """Command-line options controlling the empty-DataFrame smoke test (see try_to_run_test_for_empty_data_frame)."""

    # When set, the smoke test is executed (see try_to_run_test_for_empty_data_frame).
    pandas: luigi.BoolParameter = luigi.BoolParameter()
    namespace: luigi.OptionalStrParameter = luigi.OptionalStrParameter(
        default=None, description='When task namespace is not defined explicitly, please use "__not_user_specified".'
    )


class _TestStatus:
    """Per-task outcome of the smoke test: 'OK', or 'NG' with the captured exception."""

    def __init__(self, task: gokart.TaskOnKart[Any]) -> None:
        self.namespace = task.task_namespace
        self.name = type(task).__name__
        self.task_id = task.make_unique_id()
        self.status = 'OK'
        self.message: Exception | None = None

    def format(self) -> str:
        s = f'status={self.status}; namespace={self.namespace}; name={self.name}; id={self.task_id};'
        if self.message:
            s += f' message={type(self.message)}: {", ".join(map(str, self.message.args))}'
        return s

    def fail(self) -> bool:
        return self.status != 'OK'


def _get_all_tasks(task: gokart.TaskOnKart[Any]) -> list[gokart.TaskOnKart[Any]]:
    # Collect `task` and all of its transitive requirements, depth-first.
    result = [task]
    for o in flatten(task.requires()):
        result.extend(_get_all_tasks(o))
    return result


def _run_with_test_status(task: gokart.TaskOnKart[Any]) -> _TestStatus:
    # Run the task, recording any raised exception instead of propagating it.
    test_message = _TestStatus(task)
    try:
        task.run()
    except Exception as e:
        test_message.status = 'NG'
        test_message.message = e
    return test_message


def _test_run_with_empty_data_frame(cmdline_args: list[str], test_run_params: test_run) -> None:
    from unittest.mock import patch

    # First make sure the original workflow succeeds as-is.
    try:
        gokart.run(cmdline_args=cmdline_args)
    except SystemExit as e:
        assert e.code == 0, f'original workflow does not run properly. It exited with error code {e}.'

    with CmdlineParser.global_instance(cmdline_args) as cp:
        all_tasks = _get_all_tasks(cp.get_task_obj())

    if test_run_params.namespace is not None:
        all_tasks = [t for t in all_tasks if t.task_namespace == test_run_params.namespace]

    # Re-run every task with dump() disabled so each task sees its inputs but writes nothing.
    with patch('gokart.TaskOnKart.dump', new=lambda *args, **kwargs: None):
        test_status_list = [_run_with_test_status(t) for t in all_tasks]

    test_logger.info('gokart test results:\n' + '\n'.join(s.format() for s in test_status_list))
    if any(s.fail() for s in test_status_list):
        sys.exit(1)


def try_to_run_test_for_empty_data_frame(cmdline_args: list[str]) -> None:
    """Run the smoke test and exit the process when the `--test-run-pandas` style flag is set."""
    with CmdlineParser.global_instance(cmdline_args):
        test_run_params = test_run()
    if test_run_params.pandas:
        # Strip the test-run options before re-invoking the workflow.
        cmdline_args = [a for a in cmdline_args if not a.startswith('--test-run-')]
        _test_run_with_empty_data_frame(cmdline_args=cmdline_args, test_run_params=test_run_params)
        sys.exit(0)


# ================================================ FILE: gokart/testing/pandas_assert.py ================================================
from __future__ import annotations

from typing import Any

import pandas as pd
""" assert isinstance(actual, pd.DataFrame), 'actual is not a DataFrame' assert isinstance(expected, pd.DataFrame), 'expected is not a DataFrame' assert actual.index.is_unique, 'actual index is not unique' assert expected.index.is_unique, 'expected index is not unique' assert actual.columns.is_unique, 'actual columns is not unique' assert expected.columns.is_unique, 'expected columns is not unique' assert set(actual.columns) == set(expected.columns), 'columns are not equal' assert set(actual.index) == set(expected.index), 'indexes are not equal' expected_reindexed = expected.reindex(actual.index)[actual.columns] pd.testing.assert_frame_equal(actual, expected_reindexed, **kwargs) ================================================ FILE: gokart/tree/task_info.py ================================================ from __future__ import annotations import os from typing import Any import pandas as pd from gokart.target import make_target from gokart.task import TaskOnKart from gokart.tree.task_info_formatter import make_task_info_tree, make_tree_info, make_tree_info_table_list def make_task_info_as_tree_str(task: TaskOnKart[Any], details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None) -> str: """ Return a string representation of the tasks, their statuses/parameters in a dependency tree format Parameters ---------- - task: TaskOnKart Root task. - details: bool Whether or not to output details. - abbr: bool Whether or not to simplify tasks information that has already appeared. - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- - tree_info : str Formatted task dependency tree. 
""" task_info = make_task_info_tree(task, ignore_task_names=ignore_task_names) result = make_tree_info(task_info=task_info, indent='', last=True, details=details, abbr=abbr, visited_tasks=set()) return result def make_task_info_as_table(task: TaskOnKart[Any], ignore_task_names: list[str] | None = None) -> pd.DataFrame: """Return a table containing information about dependent tasks. Parameters ---------- - task: TaskOnKart Root task. - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- - task_info_table : pandas.DataFrame Formatted task dependency table. """ task_info = make_task_info_tree(task, ignore_task_names=ignore_task_names) task_info_table = pd.DataFrame(make_tree_info_table_list(task_info=task_info, visited_tasks=set())) return task_info_table def dump_task_info_table(task: TaskOnKart[Any], task_info_dump_path: str, ignore_task_names: list[str] | None = None) -> None: """Dump a table containing information about dependent tasks. Parameters ---------- - task: TaskOnKart Root task. - task_info_dump_path: str Output target file path. Path destination can be `local`, `S3`, or `GCS`. File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`. See `TaskOnKart.make_target module ` for details. - ignore_task_names: list[str] | None List of task names to ignore. Returns ------- None """ task_info_table = make_task_info_as_table(task=task, ignore_task_names=ignore_task_names) unique_id = task.make_unique_id() task_info_target = make_target(file_path=task_info_dump_path, unique_id=unique_id) task_info_target.dump(obj=task_info_table, lock_at_dump=False) def dump_task_info_tree(task: TaskOnKart[Any], task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True) -> None: """Dump the task info tree object (TaskInfo) to a pickle file. Parameters ---------- - task: TaskOnKart Root task. - task_info_dump_path: str Output target file path. 
Path destination can be `local`, `S3`, or `GCS`. File extension must be '.pkl'. - ignore_task_names: list[str] | None List of task names to ignore. - use_unique_id: bool = True Whether to use unique id to dump target file. Default is True. Returns ------- None """ extension = os.path.splitext(task_info_dump_path)[1] assert extension == '.pkl', f'File extention must be `.pkl`, not `{extension}`.' task_info_tree = make_task_info_tree(task, ignore_task_names=ignore_task_names) unique_id = task.make_unique_id() if use_unique_id else None task_info_target = make_target(file_path=task_info_dump_path, unique_id=unique_id) task_info_target.dump(obj=task_info_tree, lock_at_dump=False) ================================================ FILE: gokart/tree/task_info_formatter.py ================================================ from __future__ import annotations import typing import warnings from dataclasses import dataclass from typing import Any, NamedTuple from gokart.task import TaskOnKart from gokart.utils import FlattenableItems, flatten @dataclass class TaskInfo: name: str unique_id: str output_paths: list[str] params: dict[str, Any] processing_time: str is_complete: str task_log: dict[str, Any] requires: FlattenableItems[RequiredTask] children_task_infos: list[TaskInfo] def get_task_id(self): return f'{self.name}_{self.unique_id}' def get_task_title(self): return f'({self.is_complete}) {self.name}[{self.unique_id}]' def get_task_detail(self): return f'(parameter={self.params}, output={self.output_paths}, time={self.processing_time}, task_log={self.task_log})' def task_info_dict(self): return dict( name=self.name, unique_id=self.unique_id, output_paths=self.output_paths, params=self.params, processing_time=self.processing_time, is_complete=self.is_complete, task_log=self.task_log, requires=self.requires, ) class RequiredTask(NamedTuple): name: str unique_id: str def _make_requires_info(requires): if isinstance(requires, TaskOnKart): return 
def _make_requires_info(requires):
    # Mirror the structure of Task.requires() (single task / dict / iterable)
    # into RequiredTask name+id records.
    if isinstance(requires, TaskOnKart):
        return RequiredTask(name=requires.__class__.__name__, unique_id=requires.make_unique_id())
    elif isinstance(requires, dict):
        return {key: _make_requires_info(requires=item) for key, item in requires.items()}
    elif isinstance(requires, typing.Iterable):
        return [_make_requires_info(requires=item) for item in requires]
    raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterarble[TaskOnKart]`, or `Dict[str, TaskOnKart]`')


def make_task_info_tree(task: TaskOnKart[Any], ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo:
    """Build a TaskInfo tree for `task` and its requirements, memoizing repeated subtrees in `cache`."""
    with warnings.catch_warnings():
        warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method')
        is_task_complete = task.complete()

    name = task.__class__.__name__
    unique_id = task.make_unique_id()
    output_paths: list[str] = [t.path() for t in flatten(task.output())]

    cache = {} if cache is None else cache
    # Completion state is part of the key so a task is re-inspected when its status changes.
    cache_id = f'{name}_{unique_id}_{is_task_complete}'
    if cache_id in cache:
        return cache[cache_id]

    params = task.get_info(only_significant=True)
    processing_time = task.get_processing_time()
    if isinstance(processing_time, float):
        processing_time = str(processing_time) + 's'
    is_complete = 'COMPLETE' if is_task_complete else 'PENDING'
    task_log = dict(task.get_task_log())
    requires = _make_requires_info(task.requires())

    children = flatten(task.requires())
    children_task_infos: list[TaskInfo] = []
    for child in children:
        if ignore_task_names is None or child.__class__.__name__ not in ignore_task_names:
            children_task_infos.append(make_task_info_tree(child, ignore_task_names=ignore_task_names, cache=cache))

    task_info = TaskInfo(
        name=name,
        unique_id=unique_id,
        output_paths=output_paths,
        params=params,
        processing_time=processing_time,
        is_complete=is_complete,
        task_log=task_log,
        requires=requires,
        children_task_infos=children_task_infos,
    )
    cache[cache_id] = task_info
    return task_info


def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: set[str]) -> str:
    """Render `task_info` and its children as an ASCII tree; with `abbr`, repeated subtrees collapse to '...'."""
    # NOTE(review): the exact widths of the continuation indent strings below were
    # destroyed by whitespace-collapsing extraction — confirm against the original file.
    result = '\n' + indent
    if last:
        result += '└─-'
        indent += '   '
    else:
        result += '|--'
        indent += '|  '
    result += task_info.get_task_title()

    if abbr:
        task_id = task_info.get_task_id()
        if task_id not in visited_tasks:
            visited_tasks.add(task_id)
        else:
            # Already printed elsewhere: abbreviate this subtree.
            result += f'\n{indent}└─- ...'
            return result

    if details:
        result += task_info.get_task_detail()

    children = task_info.children_task_infos
    for index, child in enumerate(children):
        result += make_tree_info(child, indent, (index + 1) == len(children), details=details, abbr=abbr, visited_tasks=visited_tasks)
    return result


def make_tree_info_table_list(task_info: TaskInfo, visited_tasks: set[str]) -> list[dict[str, typing.Any]]:
    """Flatten the TaskInfo tree into a list of row dicts, emitting each task at most once."""
    task_id = task_info.get_task_id()
    if task_id in visited_tasks:
        return []
    visited_tasks.add(task_id)
    result = [task_info.task_info_dict()]
    children = task_info.children_task_infos
    for child in children:
        result += make_tree_info_table_list(task_info=child, visited_tasks=visited_tasks)
    return result


# ================================================ FILE: gokart/utils.py ================================================
from __future__ import annotations

import os
from collections.abc import Callable, Iterable
from io import BytesIO
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin

import dill
import luigi
import pandas as pd


class FileLike(Protocol):
    """Structural type for the readable, seekable binary file objects gokart loads from."""

    def read(self, n: int) -> bytes: ...

    def readline(self) -> bytes: ...

    def seek(self, offset: int) -> None: ...

    def seekable(self) -> bool: ...
def add_config(file_path: str) -> None:
    """Register ``file_path`` as a luigi configuration source.

    The parser is derived from the file extension; the assert surfaces a
    failure from luigi when the path cannot be added.
    """
    _, ext = os.path.splitext(file_path)
    luigi.configuration.core.parser = ext  # type: ignore
    assert luigi.configuration.add_config_path(file_path)


T = TypeVar('T')
# A value, or an arbitrarily nested iterable/str-keyed-dict structure of values.
FlattenableItems: TypeAlias = T | Iterable['FlattenableItems[T]'] | dict[str, 'FlattenableItems[T]']


def flatten(targets: FlattenableItems[T]) -> list[T]:
    """
    Creates a flat list of all items in structured output (dicts, lists, items):

    .. code-block:: python

        >>> sorted(flatten({'a': 'foo', 'b': 'bar'}))
        ['bar', 'foo']
        >>> sorted(flatten(['foo', ['bar', 'troll']]))
        ['bar', 'foo', 'troll']
        >>> flatten('foo')
        ['foo']
        >>> flatten(42)
        [42]

    This method is copied and modified from
    [luigi.task.flatten](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/luigi/task.py#L958)
    in accordance with
    [Apache License 2.0](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/LICENSE).
    """
    if targets is None:
        return []
    flat = []
    if isinstance(targets, dict):
        for _, result in targets.items():
            flat += flatten(result)
        return flat
    # Strings are iterable but treated as scalars here.
    if isinstance(targets, str):
        return [targets]  # type: ignore
    if not isinstance(targets, Iterable):
        return [targets]
    for result in targets:
        flat += flatten(result)
    return flat


K = TypeVar('K')


def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
    """Apply ``func`` to every leaf of a nested dict/tuple/iterable structure, preserving its shape.

    Tuples stay tuples; any other non-str iterable becomes a list.
    """
    if isinstance(items, dict):
        return {k: map_flattenable_items(func, v) for k, v in items.items()}
    if isinstance(items, tuple):
        return tuple(map_flattenable_items(func, i) for i in items)
    # Strings are leaves, not iterables, for mapping purposes.
    if isinstance(items, str):
        return func(items)  # type: ignore
    if isinstance(items, Iterable):
        return list(map(lambda item: map_flattenable_items(func, item), items))
    return func(items)


def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any:
    """Load binary dumped by dill with pandas backward compatibility.

    pd.read_pickle can load binary dumped in backward pandas version, and also any objects dumped by pickle.
    It is unclear whether all objects dumped by dill can be loaded by pd.read_pickle, we use dill.load as a fallback.
    """
    try:
        return dill.load(file)
    except Exception:
        # Fall back to pandas' pickle reader; requires rewinding the stream first.
        assert file.seekable(), f'{file} is not seekable.'
        file.seek(0)
        return pd.read_pickle(file)


def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars', 'polars-lazy']:
    """
    Extract DataFrame type from TaskOnKart[T] type parameter.

    Examines the type parameter T of a TaskOnKart subclass to determine
    whether it uses pandas or polars DataFrames/LazyFrames.

    Args:
        task: A TaskOnKart instance or class

    Returns:
        'pandas', 'polars', or 'polars-lazy' (defaults to 'pandas' if type cannot be determined)

    Examples:
        >>> class MyTask(TaskOnKart[pd.DataFrame]): pass
        >>> get_dataframe_type_from_task(MyTask())
        'pandas'

        >>> class MyPolarsTask(TaskOnKart[pl.DataFrame]): pass
        >>> get_dataframe_type_from_task(MyPolarsTask())
        'polars'
    """
    task_class = task if isinstance(task, type) else task.__class__
    # Walk the MRO to find TaskOnKart[...] even when defined on a parent class
    mro = task_class.mro() if hasattr(task_class, 'mro') else [task_class]
    for cls in mro:
        for base in getattr(cls, '__orig_bases__', ()):
            origin = get_origin(base)
            # Match by name so this works without importing TaskOnKart here.
            if origin and hasattr(origin, '__name__') and origin.__name__ == 'TaskOnKart':
                args = get_args(base)
                if not args:
                    continue
                df_type = args[0]
                module = getattr(df_type, '__module__', '')
                # Check module name to determine DataFrame type
                if 'polars' in module:
                    name = getattr(df_type, '__name__', '')
                    if name == 'LazyFrame':
                        return 'polars-lazy'
                    return 'polars'
                elif 'pandas' in module:
                    return 'pandas'
    return 'pandas'  # Default to pandas for backward compatibility


================================================ FILE: gokart/worker.py ================================================

#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # """ The worker communicates with the scheduler and does two things: 1. Sends all tasks that has to be run 2. Gets tasks from the scheduler that should be run When running in local mode, the worker talks directly to a :py:class:`~luigi.scheduler.Scheduler` instance. When you run a central server, the worker will talk to the scheduler using a :py:class:`~luigi.rpc.RemoteScheduler` instance. Everything in this module is private to luigi and may change in incompatible ways between versions. The exception is the exception types and the :py:class:`worker` config class. """ from __future__ import annotations import collections import collections.abc import contextlib import datetime import functools import getpass import importlib import json import logging import multiprocessing import os import queue as Queue import random import signal import socket import subprocess import sys import threading import time import traceback from collections.abc import Generator from typing import Any, Literal, cast import luigi import luigi.scheduler import luigi.worker from luigi import notifications from luigi.event import Event from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, UNKNOWN, WORKER_STATE_ACTIVE, WORKER_STATE_DISABLED, RetryPolicy, Scheduler from luigi.target import Target from luigi.task import DynamicRequirements, Task, flatten from luigi.task_register import TaskClassException, load_task from luigi.task_status import RUNNING from gokart.parameter import ExplicitBoolParameter logger = logging.getLogger(__name__) # Use fork context instead of the default (spawn 
on macOS), which ensures compatibility with gokart's multiprocessing requirements.
_fork_context = multiprocessing.get_context('fork')
_ForkProcess = _fork_context.Process

# Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex,
# that may not be unlocked in child process, resulting in the process being locked indefinitely.
fork_lock = threading.Lock()

# Why we assert on _WAIT_INTERVAL_EPS:
# multiprocessing.Queue.get() is undefined for timeout=0 it seems:
# https://docs.python.org/3.4/library/multiprocessing.html#multiprocessing.Queue.get.
# I also tried with really low epsilon, but then ran into the same issue where
# the test case "test_external_dependency_worker_is_patient" got stuck. So I
# unscientifically just set the final value to a floating point number that
# "worked for me".
_WAIT_INTERVAL_EPS = 0.00001


def _is_external(task: Task) -> bool:
    # A task with no run() implementation is an "external" data dependency.
    return task.run is None or task.run == NotImplemented


def _get_retry_policy_dict(task: Task) -> dict[str, Any]:
    # Serialize the task's retry policy into the dict form the scheduler RPC expects.
    return RetryPolicy(task.retry_count, task.disable_hard_timeout, task.disable_window)._asdict()  # type: ignore


# Response tuple returned by Worker._get_work(); mirrors the scheduler's get_work payload.
GetWorkResponse = collections.namedtuple(
    'GetWorkResponse',
    (
        'task_id',
        'running_tasks',
        'n_pending_tasks',
        'n_unique_pending',
        'n_pending_last_scheduled',
        'worker_state',
    ),
)


class TaskProcess(_ForkProcess):  # type: ignore[valid-type, misc]
    """Wrap all task execution in this class.
    Mainly for convenience since this is run in a separate process."""

    # mapping of status_reporter attributes to task attributes that are added to tasks
    # before they actually run, and removed afterwards
    forward_reporter_attributes = {
        'update_tracking_url': 'set_tracking_url',
        'update_status_message': 'set_status_message',
        'update_progress_percentage': 'set_progress_percentage',
        'decrease_running_resources': 'decrease_running_resources',
        'scheduler_messages': 'scheduler_messages',
    }

    def __init__(
        self,
        task: luigi.Task,
        worker_id: str,
        result_queue: multiprocessing.Queue[Any],
        status_reporter: luigi.worker.TaskStatusReporter,
        use_multiprocessing: bool = False,
        worker_timeout: int = 0,
        check_unfulfilled_deps: bool = True,
        check_complete_on_run: bool = False,
        task_completion_cache: dict[str, Any] | None = None,
        task_completion_check_at_run: bool = True,
    ) -> None:
        super().__init__()
        self.task = task
        self.worker_id = worker_id
        self.result_queue = result_queue
        self.status_reporter = status_reporter
        # A per-task worker_timeout overrides the worker-level default.
        self.worker_timeout = task.worker_timeout if task.worker_timeout is not None else worker_timeout
        self.timeout_time = time.time() + self.worker_timeout if self.worker_timeout else None
        # A timeout can only be enforced on a real subprocess, so it forces multiprocessing.
        self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None
        self.check_unfulfilled_deps = check_unfulfilled_deps
        self.check_complete_on_run = check_complete_on_run
        self.task_completion_cache = task_completion_cache
        self.task_completion_check_at_run = task_completion_check_at_run
        # completeness check using the cache
        self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

    def _run_task(self) -> collections.abc.Generator[Any, Any, Any] | None:
        # Skip re-running tasks that completed elsewhere (e.g. another worker) since scheduling.
        if self.task_completion_check_at_run and self.check_complete(self.task):
            logger.warning(f'{self.task} is skipped because the task is already completed.')
            return None
        return cast(collections.abc.Generator[Any, Any, Any] | None, self.task.run())

    def _run_get_new_deps(self) ->
                              list[tuple[str, str, dict[str, str]]] | None:
        """Drive a dynamic-dependency generator returned by run().

        Returns a list of (module, family, params) triples for unfulfilled new
        dependencies, or None when the task ran to completion.
        """
        task_gen = self._run_task()
        if not isinstance(task_gen, collections.abc.Generator):
            return None

        next_send = None
        while True:
            try:
                if next_send is None:
                    requires = next(task_gen)
                else:
                    requires = task_gen.send(next_send)
            except StopIteration:
                return None

            # if requires is not a DynamicRequirements, create one to use its default behavior
            if not isinstance(requires, DynamicRequirements):
                requires = DynamicRequirements(requires)

            if not requires.complete(self.check_complete):
                # not all requirements are complete, return them which adds them to the tree
                new_deps = [(t.task_module, t.task_family, t.to_str_params()) for t in requires.flat_requirements]
                return new_deps

            # get the next generator result
            next_send = requires.paths

    def run(self) -> None:
        """Execute the wrapped task and report (status, expl, missing, new_deps) on the result queue."""
        logger.info('[pid %s] Worker %s running %s', os.getpid(), self.worker_id, self.task)

        if self.use_multiprocessing:
            # Need to have different random seeds if running in separate processes
            processID = os.getpid()
            currentTime = time.time()
            random.seed(processID * currentTime)

        status: str | None = FAILED
        expl = ''
        missing: list[str] = []
        new_deps: list[tuple[str, str, dict[str, str]]] | None = []
        try:
            # Verify that all the tasks are fulfilled! For external tasks we
            # don't care about unfulfilled dependencies, because we are just
            # checking completeness of self.task so outputs of dependencies are
            # irrelevant.
            if self.check_unfulfilled_deps and not _is_external(self.task):
                missing = []
                for dep in self.task.deps():
                    if not self.check_complete(dep):
                        # Report which concrete outputs are absent when we can tell.
                        nonexistent_outputs = [output for output in flatten(dep.output()) if not output.exists()]
                        if nonexistent_outputs:
                            missing.append(f'{dep.task_id} ({", ".join(map(str, nonexistent_outputs))})')
                        else:
                            missing.append(dep.task_id)
                if missing:
                    deps = 'dependency' if len(missing) == 1 else 'dependencies'
                    raise RuntimeError('Unfulfilled {} at run time: {}'.format(deps, ', '.join(missing)))
            self.task.trigger_event(Event.START, self.task)
            t0 = time.time()
            status = None

            if _is_external(self.task):
                # External task
                if self.check_complete(self.task):
                    status = DONE
                else:
                    status = FAILED
                    expl = 'Task is an external data dependency and data does not exist (yet?).'
            else:
                with self._forward_attributes():
                    new_deps = self._run_get_new_deps()
                if not new_deps:
                    if not self.check_complete_on_run:
                        # update the cache
                        if self.task_completion_cache is not None:
                            self.task_completion_cache[self.task.task_id] = True
                        status = DONE
                    elif self.check_complete(self.task):
                        status = DONE
                    else:
                        raise luigi.worker.TaskException('Task finished running, but complete() is still returning false.')
                else:
                    # New dynamic dependencies surfaced: the task goes back to PENDING.
                    status = PENDING

            if new_deps:
                logger.info('[pid %s] Worker %s new requirements %s', os.getpid(), self.worker_id, self.task)
            elif status == DONE:
                self.task.trigger_event(Event.PROCESSING_TIME, self.task, time.time() - t0)
                expl = self.task.on_success()
                logger.info('[pid %s] Worker %s done %s', os.getpid(), self.worker_id, self.task)
                self.task.trigger_event(Event.SUCCESS, self.task)
        except KeyboardInterrupt:
            raise
        except BaseException as ex:
            status = FAILED
            expl = self._handle_run_exception(ex)
        finally:
            # Always report back, even on failure, so the parent worker can proceed.
            self.result_queue.put((self.task.task_id, status, expl, missing, new_deps))

    def _handle_run_exception(self, ex: BaseException) -> str:
        # Log, fire the FAILURE event, and let the task build its own explanation text.
        logger.exception('[pid %s] Worker %s failed %s', os.getpid(), self.worker_id, self.task)
        self.task.trigger_event(Event.FAILURE, self.task, ex)
        return cast(str, self.task.on_failure(ex))

    def _recursive_terminate(self) -> None:
        # Requires psutil (imported lazily so it stays an optional dependency).
        import psutil

        try:
            parent = psutil.Process(self.pid)
            children = parent.children(recursive=True)

            # terminate parent. Give it a chance to clean up
            super().terminate()
            parent.wait()

            # terminate children
            for child in children:
                try:
                    child.terminate()
                except psutil.NoSuchProcess:
                    continue
        except psutil.NoSuchProcess:
            return

    def terminate(self) -> None:
        """Terminate this process and its subprocesses."""
        # default terminate() doesn't cleanup child processes, it orphans them.
        try:
            return self._recursive_terminate()
        except ImportError:
            # psutil unavailable: fall back to plain terminate (children may be orphaned).
            super().terminate()

    @contextlib.contextmanager
    def _forward_attributes(self):
        # forward configured attributes to the task
        for reporter_attr, task_attr in self.forward_reporter_attributes.items():
            setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr))
        try:
            yield self
        finally:
            # reset attributes again
            for _, task_attr in self.forward_reporter_attributes.items():
                setattr(self.task, task_attr, None)


# This code and the task_process_context config key currently feels a bit ad-hoc.
# Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
class ContextManagedTaskProcess(TaskProcess):
    """TaskProcess whose run() is wrapped by a user-supplied context-manager class."""

    def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None:
        # context: fully qualified 'module.ClassName' string, or falsy to disable.
        super().__init__(*args, **kwargs)
        self.context = context

    def run(self) -> None:
        if self.context:
            logger.debug('Importing module and instantiating ' + self.context)
            module_path, class_name = self.context.rsplit('.', 1)
            module = importlib.import_module(module_path)
            cls = getattr(module, class_name)
            # The context class is constructed with this TaskProcess as its only argument.
            with cls(self):
                super().run()
        else:
            super().run()


class gokart_worker(luigi.Config):
    """Configuration for the gokart worker.

    You can set these options of section [gokart_worker] in your luigi.cfg file.

    NOTE: use snake_case for this class to match the luigi.Config convention.
""" id: luigi.StrParameter = luigi.StrParameter(default='', description='Override the auto-generated worker_id') ping_interval: luigi.FloatParameter = luigi.FloatParameter( default=1.0, config_path=dict(section='core', name='worker-ping-interval'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 ) keep_alive: luigi.BoolParameter = luigi.BoolParameter( default=False, config_path=dict(section='core', name='worker-keep-alive'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 ) count_uniques: luigi.BoolParameter = luigi.BoolParameter( default=False, config_path=dict(section='core', name='worker-count-uniques'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 description='worker-count-uniques means that we will keep a worker alive only if it has a unique pending task, as well as having keep-alive true', ) count_last_scheduled: luigi.BoolParameter = luigi.BoolParameter( default=False, description='Keep a worker alive only if there are pending tasks which it was the last to schedule.' 
) wait_interval: luigi.FloatParameter = luigi.FloatParameter( default=1.0, config_path=dict(section='core', name='worker-wait-interval'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 ) wait_jitter: luigi.FloatParameter = luigi.FloatParameter(default=5.0) max_keep_alive_idle_duration: luigi.TimeDeltaParameter = luigi.TimeDeltaParameter(default=datetime.timedelta(0)) max_reschedules: luigi.IntParameter = luigi.IntParameter( default=1, config_path=dict(section='core', name='worker-max-reschedules'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 ) timeout: luigi.IntParameter = luigi.IntParameter( default=0, config_path=dict(section='core', name='worker-timeout'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 ) task_limit: luigi.OptionalIntParameter = luigi.OptionalIntParameter( default=None, # type: ignore[arg-type] # OptionalIntParameter.__init__ inherits IntParameter's signature config_path=dict(section='core', name='worker-task-limit'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 ) retry_external_tasks: luigi.BoolParameter = luigi.BoolParameter( default=False, config_path=dict(section='core', name='retry-external-tasks'), # type: ignore # fix https://github.com/spotify/luigi/pull/3403 description='If true, incomplete external tasks will be retested for completion while Luigi is running.', ) send_failure_email: luigi.BoolParameter = luigi.BoolParameter(default=True, description='If true, send e-mails directly from the workeron failure') no_install_shutdown_handler: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If true, the SIGUSR1 shutdown handler willNOT be install on the worker' ) check_unfulfilled_deps: luigi.BoolParameter = luigi.BoolParameter( default=True, description='If true, check for completeness of dependencies before running a task' ) check_complete_on_run: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If true, only mark tasks as 
done after running if they are complete. ' 'Regardless of this setting, the worker will always check if external ' 'tasks are complete before marking them as done.', ) force_multiprocessing: luigi.BoolParameter = luigi.BoolParameter(default=False, description='If true, use multiprocessing also when running with 1 worker') task_process_context: luigi.OptionalStrParameter = luigi.OptionalStrParameter( default=None, description='If set to a fully qualified class name, the class will ' 'be instantiated with a TaskProcess as its constructor parameter and ' 'applied as a context manager around its run() call, so this can be ' 'used for obtaining high level customizable monitoring or logging of ' 'each individual Task run.', ) cache_task_completion: luigi.BoolParameter = luigi.BoolParameter( default=False, description='If true, cache the response of successful completion checks ' 'of tasks assigned to a worker. This can especially speed up tasks with ' 'dynamic dependencies but assumes that the completion status does not change ' 'after it was true the first time.', ) task_completion_check_at_run: luigi.BoolParameter = ExplicitBoolParameter( default=True, description='If true, tasks completeness will be re-checked just before the run, in case they are finished elsewhere.' ) class Worker: """ Worker object communicates with a scheduler. 
    Simple class that talks to a scheduler and:

    * tells the scheduler what it has to do + its dependencies
    * asks for stuff to do (pulls it in a loop and runs it)
    """

    def __init__(
        self,
        scheduler: Scheduler | None = None,
        worker_id: str | None = None,
        worker_processes: int = 1,
        assistant: bool = False,
        config: gokart_worker | None = None,
    ) -> None:
        if scheduler is None:
            # Default to an in-process (local) scheduler.
            scheduler = Scheduler()

        self.worker_processes = int(worker_processes)
        self._worker_info = self._generate_worker_info()
        if config is None:
            self._config = gokart_worker()
        else:
            self._config = config

        # Explicit argument wins, then config, then an auto-generated id.
        worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info)

        assert self._config.wait_interval >= _WAIT_INTERVAL_EPS, '[worker] wait_interval must be positive'
        assert self._config.wait_jitter >= 0.0, '[worker] wait_jitter must be equal or greater than zero'

        self._id = worker_id
        self._scheduler = scheduler
        self._assistant = assistant
        self._stop_requesting_work = False

        self.host = socket.gethostname()
        self._scheduled_tasks: dict[str, Task] = {}
        self._suspended_tasks: dict[str, Task] = {}
        self._batch_running_tasks: dict[str, Any] = {}
        self._batch_families_sent: set[str] = set()

        self._first_task = None

        self.add_succeeded = True
        self.run_succeeded = True

        self.unfulfilled_counts: dict[str, int] = collections.defaultdict(int)

        # note that ``signal.signal(signal.SIGUSR1, fn)`` only works inside the main execution thread, which is why we
        # provide the ability to conditionally install the hook.
        if not self._config.no_install_shutdown_handler:
            try:
                signal.signal(signal.SIGUSR1, self.handle_interrupt)
                signal.siginterrupt(signal.SIGUSR1, False)
            except AttributeError:
                # e.g. platforms without SIGUSR1 support.
                pass

        # Keep info about what tasks are running (could be in other processes)
        self._task_result_queue: multiprocessing.Queue[Any] = _fork_context.Queue()
        self._running_tasks: dict[str, TaskProcess] = {}
        self._idle_since: datetime.datetime | None = None

        # mp-safe dictionary for caching completion checks across task processes
        self._task_completion_cache = None
        if self._config.cache_task_completion:
            self._task_completion_cache = _fork_context.Manager().dict()

        # Stuff for execution_summary
        self._add_task_history: list[Any] = []
        self._get_work_response_history: list[Any] = []

    def _add_task(self, *args, **kwargs) -> None:
        """
        Call ``self._scheduler.add_task``, but store the values too so we can
        implement :py:func:`luigi.execution_summary.summary`.
        """
        task_id = kwargs['task_id']
        status = kwargs['status']
        runnable = kwargs['runnable']
        task = self._scheduled_tasks.get(task_id)
        if task:
            self._add_task_history.append((task, status, runnable))
            kwargs['owners'] = task._owner_list()

        if task_id in self._batch_running_tasks:
            # Propagate the batch result to every member of the batch.
            for batch_task in self._batch_running_tasks.pop(task_id):
                self._add_task_history.append((batch_task, status, True))

        if task and kwargs.get('params'):
            kwargs['param_visibilities'] = task._get_param_visibilities()

        self._scheduler.add_task(*args, **kwargs)

        logger.info('Informed scheduler that task %s has status %s', task_id, status)

    def __enter__(self) -> Worker:
        """
        Start the KeepAliveThread.
        """
        self._keep_alive_thread = luigi.worker.KeepAliveThread(self._scheduler, self._id, self._config.ping_interval, self._handle_rpc_message)
        self._keep_alive_thread.daemon = True
        self._keep_alive_thread.start()
        return self

    def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
        """
        Stop the KeepAliveThread and kill still running tasks.
""" self._keep_alive_thread.stop() self._keep_alive_thread.join() for task in self._running_tasks.values(): if task.is_alive(): task.terminate() self._task_result_queue.close() return False # Don't suppress exception def _generate_worker_info(self) -> list[tuple[str, Any]]: # Generate as much info as possible about the worker # Some of these calls might not be available on all OS's args = [('salt', f'{random.randrange(0, 10_000_000_000):09d}'), ('workers', self.worker_processes)] try: args += [('host', socket.gethostname())] except BaseException: pass try: args += [('username', getpass.getuser())] except BaseException: pass try: args += [('pid', os.getpid())] except BaseException: pass try: sudo_user = os.getenv('SUDO_USER') if sudo_user: args.append(('sudo_user', sudo_user)) except BaseException: pass return args def _generate_worker_id(self, worker_info: list[Any]) -> str: worker_info_str = ', '.join([f'{k}={v}' for k, v in worker_info]) return f'Worker({worker_info_str})' def _validate_task(self, task: Task) -> None: if not isinstance(task, Task): raise luigi.worker.TaskException(f'Can not schedule non-task {task}') if not task.initialized(): # we can't get the repr of it since it's not initialized... raise luigi.worker.TaskException( f'Task of class {task.__class__.__name__} not initialized. Did you override __init__ and forget to call super(...).__init__?' 
) def _log_complete_error(self, task: Task, tb: str) -> None: log_msg = f'Will not run {task} or any dependencies due to error in complete() method:\n{tb}' logger.warning(log_msg) def _log_dependency_error(self, task: Task, tb: str) -> None: log_msg = f'Will not run {task} or any dependencies due to error in deps() method:\n{tb}' logger.warning(log_msg) def _log_unexpected_error(self, task: Task) -> None: logger.exception('Luigi unexpected framework error while scheduling %s', task) # needs to be called from within except clause def _announce_scheduling_failure(self, task: Task, expl: Any) -> None: try: self._scheduler.announce_scheduling_failure( worker=self._id, task_name=str(task), family=task.task_family, params=task.to_str_params(only_significant=True), expl=expl, owners=task._owner_list(), ) except Exception: formatted_traceback = traceback.format_exc() self._email_unexpected_error(task, formatted_traceback) raise def _email_complete_error(self, task: Task, formatted_traceback: str) -> None: self._announce_scheduling_failure(task, formatted_traceback) if self._config.send_failure_email: self._email_error( task, formatted_traceback, subject='Luigi: {task} failed scheduling. Host: {host}', headline='Will not run {task} or any dependencies due to error in complete() method', ) def _email_dependency_error(self, task: Task, formatted_traceback: str) -> None: self._announce_scheduling_failure(task, formatted_traceback) if self._config.send_failure_email: self._email_error( task, formatted_traceback, subject='Luigi: {task} failed scheduling. 
Host: {host}', headline='Will not run {task} or any dependencies due to error in deps() method', ) def _email_unexpected_error(self, task: Task, formatted_traceback: str) -> None: # this sends even if failure e-mails are disabled, as they may indicate # a more severe failure that may not reach other alerting methods such # as scheduler batch notification self._email_error( task, formatted_traceback, subject='Luigi: Framework error while scheduling {task}. Host: {host}', headline='Luigi framework error', ) def _email_task_failure(self, task: Task, formatted_traceback: str) -> None: if self._config.send_failure_email: self._email_error( task, formatted_traceback, subject='Luigi: {task} FAILED. Host: {host}', headline='A task failed when running. Most likely run() raised an exception.', ) def _email_error(self, task: Task, formatted_traceback: str, subject: str, headline: str) -> None: formatted_subject = subject.format(task=task, host=self.host) formatted_headline = headline.format(task=task, host=self.host) command = subprocess.list2cmdline(sys.argv) message = notifications.format_task_error(formatted_headline, task, command, formatted_traceback) notifications.send_error_email(formatted_subject, message, task.owner_email) def _handle_task_load_error(self, exception: Exception, task_ids: list[str]) -> None: msg = 'Cannot find task(s) sent by scheduler: {}'.format(','.join(task_ids)) logger.exception(msg) subject = f'Luigi: {msg}' error_message = notifications.wrap_traceback(exception) for task_id in task_ids: self._add_task( worker=self._id, task_id=task_id, status=FAILED, runnable=False, expl=error_message, ) notifications.send_error_email(subject, error_message) def add(self, task: Task, multiprocess: bool = False, processes: int = 0) -> bool: """ Add a Task for the worker to check and possibly schedule and run. Returns True if task and its dependencies were successfully scheduled or completed before. 
""" if self._first_task is None and hasattr(task, 'task_id'): self._first_task = task.task_id self.add_succeeded = True if multiprocess: queue: Any = _fork_context.Manager().Queue() pool: Any = _fork_context.Pool(processes=processes if processes > 0 else None) else: queue = luigi.worker.DequeQueue() pool = luigi.worker.SingleProcessPool() self._validate_task(task) pool.apply_async(luigi.worker.check_complete, [task, queue, self._task_completion_cache]) # we track queue size ourselves because len(queue) won't work for multiprocessing queue_size = 1 try: seen = {task.task_id} while queue_size: current = queue.get() queue_size -= 1 item, is_complete = current for next in self._add(item, is_complete): if next.task_id not in seen: self._validate_task(next) seen.add(next.task_id) pool.apply_async(luigi.worker.check_complete, [next, queue, self._task_completion_cache]) queue_size += 1 except (KeyboardInterrupt, luigi.worker.TaskException): raise except Exception as ex: self.add_succeeded = False formatted_traceback = traceback.format_exc() self._log_unexpected_error(task) task.trigger_event(Event.BROKEN_TASK, task, ex) self._email_unexpected_error(task, formatted_traceback) raise finally: pool.close() pool.join() return self.add_succeeded def _add_task_batcher(self, task: Task) -> None: family = task.task_family if family not in self._batch_families_sent: task_class = type(task) batch_param_names = task_class.batch_param_names() if batch_param_names: self._scheduler.add_task_batcher( worker=self._id, task_family=family, batched_args=batch_param_names, max_batch_size=task.max_batch_size, ) self._batch_families_sent.add(family) def _add(self, task: Task, is_complete: bool) -> Generator[Task, None, None]: if self._config.task_limit is not None and len(self._scheduled_tasks) >= self._config.task_limit: logger.warning('Will not run %s or any dependencies due to exceeded task-limit of %d', task, self._config.task_limit) deps = None status = UNKNOWN runnable = False else: 
formatted_traceback = None try: self._check_complete_value(is_complete) except KeyboardInterrupt: raise except luigi.worker.AsyncCompletionException as ex: formatted_traceback = ex.trace except BaseException: formatted_traceback = traceback.format_exc() if formatted_traceback is not None: self.add_succeeded = False self._log_complete_error(task, formatted_traceback) task.trigger_event(Event.DEPENDENCY_MISSING, task) self._email_complete_error(task, formatted_traceback) deps = None status = UNKNOWN runnable = False elif is_complete: deps = None status = DONE runnable = False task.trigger_event(Event.DEPENDENCY_PRESENT, task) elif _is_external(task): deps = None status = PENDING runnable = self._config.retry_external_tasks task.trigger_event(Event.DEPENDENCY_MISSING, task) logger.warning('Data for %s does not exist (yet?). The task is an external data dependency, so it cannot be run from this luigi process.', task) else: try: deps = task.deps() self._add_task_batcher(task) except Exception as ex: formatted_traceback = traceback.format_exc() self.add_succeeded = False self._log_dependency_error(task, formatted_traceback) task.trigger_event(Event.BROKEN_TASK, task, ex) self._email_dependency_error(task, formatted_traceback) deps = None status = UNKNOWN runnable = False else: status = PENDING runnable = True if task.disabled: status = DISABLED if deps: for d in deps: self._validate_dependency(d) task.trigger_event(Event.DEPENDENCY_DISCOVERED, task, d) yield d # return additional tasks to add deps = [d.task_id for d in deps] self._scheduled_tasks[task.task_id] = task self._add_task( worker=self._id, task_id=task.task_id, status=status, deps=deps, runnable=runnable, priority=task.priority, resources=task.process_resources(), params=task.to_str_params(), family=task.task_family, module=task.task_module, batchable=task.batchable, retry_policy_dict=_get_retry_policy_dict(task), accepts_messages=task.accepts_messages, ) def _validate_dependency(self, dependency: Task) -> 
None: if isinstance(dependency, Target): raise Exception('requires() can not return Target objects. Wrap it in an ExternalTask class') elif not isinstance(dependency, Task): raise Exception(f'requires() must return Task objects but {dependency} is a {type(dependency)}') def _check_complete_value(self, is_complete: bool | luigi.worker.TracebackWrapper) -> None: if isinstance(is_complete, luigi.worker.TracebackWrapper): raise luigi.worker.AsyncCompletionException(is_complete.trace) if not isinstance(is_complete, bool): raise Exception(f'Return value of Task.complete() must be boolean (was {is_complete!r})') def _add_worker(self) -> None: self._worker_info.append(('first_task', self._first_task)) self._scheduler.add_worker(self._id, self._worker_info) def _log_remote_tasks(self, get_work_response: GetWorkResponse) -> None: logger.debug('Done') logger.debug('There are no more tasks to run at this time') if get_work_response.running_tasks: for r in get_work_response.running_tasks: logger.debug('%s is currently run by worker %s', r['task_id'], r['worker']) elif get_work_response.n_pending_tasks: logger.debug('There are %s pending tasks possibly being run by other workers', get_work_response.n_pending_tasks) if get_work_response.n_unique_pending: logger.debug('There are %i pending tasks unique to this worker', get_work_response.n_unique_pending) if get_work_response.n_pending_last_scheduled: logger.debug('There are %i pending tasks last scheduled by this worker', get_work_response.n_pending_last_scheduled) def _get_work_task_id(self, get_work_response: dict[str, Any]) -> str | None: if get_work_response.get('task_id') is not None: return cast(str, get_work_response['task_id']) elif 'batch_id' in get_work_response: try: task = load_task( module=get_work_response.get('task_module'), task_name=get_work_response['task_family'], params_str=get_work_response['task_params'], ) except Exception as ex: self._handle_task_load_error(ex, get_work_response['batch_task_ids']) 
                self.run_succeeded = False
                return None

            self._scheduler.add_task(
                worker=self._id,
                task_id=task.task_id,
                module=get_work_response.get('task_module'),
                family=get_work_response['task_family'],
                params=task.to_str_params(),
                status=RUNNING,
                batch_id=get_work_response['batch_id'],
            )
            return cast(str, task.task_id)
        else:
            return None

    def _get_work(self) -> GetWorkResponse:
        """Ask the scheduler for the next task to run.

        When worker_processes == 0 the worker only polls pending counts instead
        of requesting work. Tasks unknown to this worker (e.g. batch tasks) are
        loaded dynamically before being returned.
        """
        if self._stop_requesting_work:
            return GetWorkResponse(None, 0, 0, 0, 0, WORKER_STATE_DISABLED)

        if self.worker_processes > 0:
            logger.debug('Asking scheduler for work...')
            r = self._scheduler.get_work(
                worker=self._id,
                host=self.host,
                assistant=self._assistant,
                current_tasks=list(self._running_tasks.keys()),
            )
        else:
            logger.debug('Checking if tasks are still pending')
            r = self._scheduler.count_pending(worker=self._id)

        running_tasks = r['running_tasks']
        task_id = self._get_work_task_id(r)

        self._get_work_response_history.append(
            {
                'task_id': task_id,
                'running_tasks': running_tasks,
            }
        )

        if task_id is not None and task_id not in self._scheduled_tasks:
            logger.info('Did not schedule %s, will load it dynamically', task_id)
            try:
                # TODO: we should obtain the module name from the server!
                self._scheduled_tasks[task_id] = load_task(module=r.get('task_module'), task_name=r['task_family'], params_str=r['task_params'])
            except TaskClassException as ex:
                self._handle_task_load_error(ex, [task_id])
                task_id = None
                self.run_succeeded = False

        if task_id is not None and 'batch_task_ids' in r:
            batch_tasks = filter(None, [self._scheduled_tasks.get(batch_id) for batch_id in r['batch_task_ids']])
            self._batch_running_tasks[task_id] = batch_tasks

        return GetWorkResponse(
            task_id=task_id,
            running_tasks=running_tasks,
            n_pending_tasks=r['n_pending_tasks'],
            n_unique_pending=r['n_unique_pending'],
            # TODO: For a tiny amount of time (a month?) we'll keep forwards compatibility
            # That is you can use a newer client than server (Sep 2016)
            n_pending_last_scheduled=r.get('n_pending_last_scheduled', 0),
            worker_state=r.get('worker_state', WORKER_STATE_ACTIVE),
        )

    def _run_task(self, task_id: str) -> None:
        """Start the task either in a child process or inline in this process."""
        if task_id in self._running_tasks:
            # Scheduler handed back a task we are already executing; back off briefly.
            logger.debug(f'Got already running task id {task_id} from scheduler, taking a break')
            next(self._sleeper())
            return
        task = self._scheduled_tasks[task_id]

        task_process = self._create_task_process(task)

        self._running_tasks[task_id] = task_process

        if task_process.use_multiprocessing:
            # fork_lock serializes process creation against other fork users.
            with fork_lock:
                task_process.start()
        else:
            # Run in the same process
            task_process.run()

    def _create_task_process(self, task):
        """Build the ContextManagedTaskProcess that will execute ``task``."""
        # A message queue is only needed when the task accepts scheduler messages.
        message_queue: Any = _fork_context.Queue() if task.accepts_messages else None
        reporter = luigi.worker.TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
        use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1)
        return ContextManagedTaskProcess(
            self._config.task_process_context,
            task,
            self._id,
            self._task_result_queue,
            reporter,
            use_multiprocessing=use_multiprocessing,
            worker_timeout=self._config.timeout,
            check_unfulfilled_deps=self._config.check_unfulfilled_deps,
            check_complete_on_run=self._config.check_complete_on_run,
            task_completion_cache=self._task_completion_cache,
            task_completion_check_at_run=self._config.task_completion_check_at_run,
        )

    def _purge_children(self) -> None:
        """
        Find dead children and put a response on the result queue.

        :return:
        """
        for task_id, p in self._running_tasks.items():
            if not p.is_alive() and p.exitcode:
                error_msg = f'Task {task_id} died unexpectedly with exit code {p.exitcode}'
                p.task.trigger_event(Event.PROCESS_FAILURE, p.task, error_msg)
            elif p.timeout_time is not None and time.time() > float(p.timeout_time) and p.is_alive():
                # Running past its deadline: kill the child and report a timeout.
                p.terminate()
                error_msg = f'Task {task_id} timed out after {p.worker_timeout} seconds and was terminated.'
                p.task.trigger_event(Event.TIMEOUT, p.task, error_msg)
            else:
                continue
            logger.info(error_msg)
            self._task_result_queue.put((task_id, FAILED, error_msg, [], []))

    def _handle_next_task(self) -> None:
        """
        We have to catch three ways a task can be "done":

        1. normal execution: the task runs/fails and puts a result back on the queue,
        2. new dependencies: the task yielded new deps that were not complete and
           will be rescheduled and dependencies added,
        3. child process dies: we need to catch this separately.
        """
        self._idle_since = None
        while True:
            self._purge_children()  # Deal with subprocess failures

            try:
                task_id, status, expl, missing, new_requirements = self._task_result_queue.get(timeout=self._config.wait_interval)
            except Queue.Empty:
                return

            task = self._scheduled_tasks[task_id]
            if not task or task_id not in self._running_tasks:
                continue
                # Not a running task. Probably already removed.
                # Maybe it yielded something?

            # external task if run not implemented, retry-able if config option is enabled.
            external_task_retryable = _is_external(task) and self._config.retry_external_tasks
            if status == FAILED and not external_task_retryable:
                self._email_task_failure(task, expl)

            new_deps = []
            if new_requirements:
                # Dynamic dependencies yielded from run(): instantiate and add them.
                new_req = [load_task(module, name, params) for module, name, params in new_requirements]
                for t in new_req:
                    self.add(t)
                new_deps = [t.task_id for t in new_req]

            self._add_task(
                worker=self._id,
                task_id=task_id,
                status=status,
                expl=json.dumps(expl),
                resources=task.process_resources(),
                runnable=None,
                params=task.to_str_params(),
                family=task.task_family,
                module=task.task_module,
                new_deps=new_deps,
                assistant=self._assistant,
                retry_policy_dict=_get_retry_policy_dict(task),
            )

            self._running_tasks.pop(task_id)

            # re-add task to reschedule missing dependencies
            if missing:
                reschedule = True

                # keep out of infinite loops by not rescheduling too many times
                for task_id in missing:
                    self.unfulfilled_counts[task_id] += 1
                    if self.unfulfilled_counts[task_id] > self._config.max_reschedules:
                        reschedule = False
                if reschedule:
                    self.add(task)

            self.run_succeeded &= (status == DONE) or (len(new_deps) > 0)
            return

    def _sleeper(self) -> Generator[None, None, None]:
        # TODO is exponential backoff necessary?
        while True:
            # Jitter spreads out polling so many workers don't hit the scheduler in lockstep.
            jitter = self._config.wait_jitter
            wait_interval = self._config.wait_interval + random.uniform(0, jitter)
            logger.debug('Sleeping for %f seconds', wait_interval)
            time.sleep(wait_interval)
            yield

    def _keep_alive(self, get_work_response: Any) -> bool:
        """
        Returns true if a worker should stay alive given.

        If worker-keep-alive is not set, this will always return false.
        For an assistant, it will always return the value of worker-keep-alive.
        Otherwise, it will return true for nonzero n_pending_tasks.

        If worker-count-uniques is true, it will also require that one of
        the tasks is unique to this worker.
        """
        if not self._config.keep_alive:
            return False
        elif self._assistant:
            return True
        elif self._config.count_last_scheduled:
            return cast(bool, get_work_response.n_pending_last_scheduled > 0)
        elif self._config.count_uniques:
            return cast(bool, get_work_response.n_unique_pending > 0)
        elif get_work_response.n_pending_tasks == 0:
            return False
        elif not self._config.max_keep_alive_idle_duration:
            return True
        elif not self._idle_since:
            return True
        else:
            # Shut down once we have been idle longer than max_keep_alive_idle_duration.
            time_to_shutdown = self._idle_since + self._config.max_keep_alive_idle_duration - datetime.datetime.now()
            logger.debug('[%s] %s until shutdown', self._id, time_to_shutdown)
            return time_to_shutdown > datetime.timedelta(0)

    def handle_interrupt(self, signum: int, _: Any) -> None:
        """
        Stops the assistant from asking for more work on SIGUSR1
        """
        if signum == signal.SIGUSR1:
            self._start_phasing_out()

    def _start_phasing_out(self) -> None:
        """
        Go into a mode where we dont ask for more work and quit once existing
        tasks are done.
        """
        self._config.keep_alive = False
        self._stop_requesting_work = True

    def run(self) -> bool:
        """
        Returns True if all scheduled tasks were executed successfully.
        """
        logger.info('Running Worker with %d processes', self.worker_processes)

        sleeper = self._sleeper()
        self.run_succeeded = True

        self._add_worker()

        while True:
            # Block while the process pool is saturated.
            while len(self._running_tasks) >= self.worker_processes > 0:
                logger.debug('%d running tasks, waiting for next task to finish', len(self._running_tasks))
                self._handle_next_task()

            get_work_response = self._get_work()

            if get_work_response.worker_state == WORKER_STATE_DISABLED:
                self._start_phasing_out()

            if get_work_response.task_id is None:
                if not self._stop_requesting_work:
                    self._log_remote_tasks(get_work_response)
                if len(self._running_tasks) == 0:
                    # Fully idle: either keep polling (keep-alive) or exit the loop.
                    self._idle_since = self._idle_since or datetime.datetime.now()
                    if self._keep_alive(get_work_response):
                        next(sleeper)
                        continue
                    else:
                        break
                else:
                    self._handle_next_task()
                    continue

            # task_id is not None:
            logger.debug('Pending tasks: %s', get_work_response.n_pending_tasks)
            self._run_task(get_work_response.task_id)

        while len(self._running_tasks):
            logger.debug('Shut down Worker, %d more tasks to go', len(self._running_tasks))
            self._handle_next_task()

        return self.run_succeeded

    def _handle_rpc_message(self, message: dict[str, Any]) -> None:
        """Dispatch an RPC message from the scheduler to a decorated callback method."""
        logger.info(f'Worker {self._id} got message {message}')

        # the message is a dict {'name': ..., 'kwargs': ...} (function name and its kwargs)
        name = message['name']
        kwargs = message['kwargs']

        # find the function and check if it's callable and configured to work
        # as a message callback
        func = getattr(self, name, None)
        tpl = (self._id, name)
        if not callable(func):
            logger.error("Worker {} has no function '{}'".format(*tpl))
        elif not getattr(func, 'is_rpc_message_callback', False):
            logger.error("Worker {} function '{}' is not available as rpc message callback".format(*tpl))
        else:
            logger.info("Worker {} successfully dispatched rpc message to function '{}'".format(*tpl))
            func(**kwargs)

    @luigi.worker.rpc_message_callback
    def set_worker_processes(self, n: int) -> None:
        # set the new value
        self.worker_processes = max(1, n)

        # tell the scheduler
        self._scheduler.add_worker(self._id,
                                   {'workers': self.worker_processes})

    @luigi.worker.rpc_message_callback
    def dispatch_scheduler_message(self, task_id: str, message_id: str, content: str, **kwargs: Any) -> None:
        """Forward a scheduler message to the running task's message queue, if it has one."""
        task_id = str(task_id)
        if task_id in self._running_tasks:
            task_process = self._running_tasks[task_id]
            if task_process.status_reporter.scheduler_messages:
                message = luigi.worker.SchedulerMessage(self._scheduler, task_id, message_id, content, **kwargs)
                task_process.status_reporter.scheduler_messages.put(message)


================================================
FILE: gokart/workspace_management.py
================================================
from __future__ import annotations

import itertools
import os
import pathlib
from logging import getLogger
from typing import Any

import gokart
from gokart.utils import flatten

logger = getLogger(__name__)


def _get_all_output_file_paths(task: gokart.TaskOnKart[Any]) -> list[str]:
    """Recursively collect output file paths of ``task`` and all of its required tasks."""
    output_paths = [t.path() for t in flatten(task.output())]
    children = flatten(task.requires())
    output_paths.extend(itertools.chain.from_iterable([_get_all_output_file_paths(child) for child in children]))
    return output_paths


def delete_local_unnecessary_outputs(task: gokart.TaskOnKart[Any]) -> None:
    """Delete files in the local workspace that are neither outputs of the task tree nor log files."""
    task.make_unique_id()  # this is required to make unique ids.
    all_files = {str(path) for path in pathlib.Path(task.workspace_directory).rglob('*.*')}
    log_files = {str(path) for path in pathlib.Path(os.path.join(task.workspace_directory, 'log')).rglob('*.*')}
    necessary_files = set(_get_all_output_file_paths(task))
    unnecessary_files = all_files - necessary_files - log_files
    if len(unnecessary_files) == 0:
        logger.info('all files are necessary for this task.')
    else:
        logger.info(f'remove following files: {os.linesep} {os.linesep.join(unnecessary_files)}')
        for file in unnecessary_files:
            os.remove(file)


================================================
FILE: gokart/zip_client.py
================================================
from __future__ import annotations

import os
import shutil
import zipfile
from abc import abstractmethod
from typing import IO


def _unzip_file(fp: str | IO[bytes] | os.PathLike[str], extract_dir: str) -> None:
    """Extract every member of the zip archive ``fp`` into ``extract_dir``."""
    zip_file = zipfile.ZipFile(fp)
    zip_file.extractall(extract_dir)
    zip_file.close()


class ZipClient:
    """Abstract interface for archiving/unarchiving a directory to a zip file."""

    @abstractmethod
    def exists(self) -> bool:
        pass

    @abstractmethod
    def make_archive(self) -> None:
        pass

    @abstractmethod
    def unpack_archive(self) -> None:
        pass

    @abstractmethod
    def remove(self) -> None:
        pass

    @property
    @abstractmethod
    def path(self) -> str:
        pass


class LocalZipClient(ZipClient):
    """ZipClient backed by the local filesystem."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory

    def exists(self) -> bool:
        return os.path.exists(self._file_path)

    def make_archive(self) -> None:
        # shutil.make_archive wants the base name and the format (extension without the dot) separately.
        [base, extension] = os.path.splitext(self._file_path)
        shutil.make_archive(base_name=base, format=extension[1:], root_dir=self._temporary_directory)

    def unpack_archive(self) -> None:
        _unzip_file(fp=self._file_path, extract_dir=self._temporary_directory)

    def remove(self) -> None:
        shutil.rmtree(self._file_path, ignore_errors=True)

    @property
    def path(self) -> str:
        return self._file_path


================================================
FILE: gokart/zip_client_util.py
================================================ from __future__ import annotations from gokart.object_storage import ObjectStorage from gokart.zip_client import LocalZipClient, ZipClient def make_zip_client(file_path: str, temporary_directory: str) -> ZipClient: if ObjectStorage.if_object_storage_path(file_path): return ObjectStorage.get_zip_client(file_path=file_path, temporary_directory=temporary_directory) return LocalZipClient(file_path=file_path, temporary_directory=temporary_directory) ================================================ FILE: luigi.cfg ================================================ [core] autoload_range: false ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["hatchling", "uv-dynamic-versioning"] build-backend = "hatchling.build" [project] name = "gokart" description="Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline. 
[Documentation](https://gokart.readthedocs.io/en/latest/)" authors = [ {name = "M3, inc."} ] license = "MIT" readme = "README.md" requires-python = ">=3.10, <4" dependencies = [ "luigi>=3.8.0", "boto3", "slack-sdk", "pandas", "numpy", "google-auth", "pyarrow", "google-api-python-client", "APScheduler", "redis", "dill", "backoff", "typing-extensions>=4.11.0; python_version<'3.13'", ] classifiers = [ "Development Status :: 5 - Production/Stable", "License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13", "Programming Language :: Python :: 3.14", ] dynamic = ["version"] [project.optional-dependencies] polars = ["polars>=0.19.0"] [project.urls] Homepage = "https://github.com/m3dev/gokart" Repository = "https://github.com/m3dev/gokart" Documentation = "https://gokart.readthedocs.io/en/latest/" [dependency-groups] test = [ "fakeredis", "lupa", "matplotlib", "moto>=4.0", # for use mock_aws api "mypy", "polars>=0.19.0", "pytest", "pytest-cov", "pytest-xdist", "testfixtures", "toml", "types-redis", "typing-extensions>=4.11.0", ] lint = [ "ruff", "mypy", ] [tool.uv] default-groups = ['test', 'lint'] cache-keys = [ { file = "pyproject.toml" }, { git = true } ] [tool.hatch.version] source = "uv-dynamic-versioning" [tool.uv-dynamic-versioning] enable = true [tool.hatch.build.targets.sdist] include = [ "/LICENSE", "/README.md", "/examples", "/gokart", "/test", ] [tool.ruff] line-length = 160 exclude = ["venv/*", "tox/*", "examples/*"] [tool.ruff.lint] # All the rules are listed on https://docs.astral.sh/ruff/rules/ extend-select = [ "B", # bugbear "I", # isort "UP", # pyupgrade, upgrade syntax for newer versions of the language. ] # B006: Do not use mutable data structures for argument defaults. They are created during function definition time. 
All calls to the function reuse this one instance of that data structure, persisting changes between them. # B008 Do not perform function calls in argument defaults. The call is performed only once at function definition time. All calls to your function will reuse the result of that definition-time function call. If this is intended, assign the function call to a module-level variable and use that variable as a default value. ignore = ["B006", "B008"] [tool.ruff.format] quote-style = "single" [tool.mypy] ignore_missing_imports = true check_untyped_defs = true warn_unused_configs = true warn_redundant_casts = true no_implicit_optional = true strict_optional = true strict_equality = true warn_unused_ignores = true warn_return_any = true disallow_incomplete_defs = true disallow_any_generics = true [tool.pytest.ini_options] testpaths = ["test"] addopts = "-n auto -s -v --durations=0" ================================================ FILE: test/__init__.py ================================================ ================================================ FILE: test/config/__init__.py ================================================ from pathlib import Path from typing import Final CONFIG_DIR: Final[Path] = Path(__file__).parent.resolve() PYPROJECT_TOML: Final[Path] = CONFIG_DIR / 'pyproject.toml' PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS: Final[Path] = CONFIG_DIR / 'pyproject_disallow_missing_parameters.toml' TEST_CONFIG_INI: Final[Path] = CONFIG_DIR / 'test_config.ini' ================================================ FILE: test/config/pyproject.toml ================================================ [tool.mypy] plugins = ["gokart.mypy"] [[tool.mypy.overrides]] ignore_missing_imports = true module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"] ================================================ FILE: test/config/pyproject_disallow_missing_parameters.toml ================================================ [tool.mypy] plugins = 
["gokart.mypy"] [[tool.mypy.overrides]] ignore_missing_imports = true module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"] [tool.gokart-mypy] disallow_missing_parameters = true ================================================ FILE: test/config/test_config.ini ================================================ [test_read_config._DummyTask] param = ${test_param} [test_build._DummyTask] param = ${test_param} ================================================ FILE: test/conflict_prevention_lock/__init__.py ================================================ ================================================ FILE: test/conflict_prevention_lock/test_task_lock.py ================================================ import random import unittest from typing import Any from unittest.mock import patch import gokart from gokart.conflict_prevention_lock.task_lock import RedisClient, TaskLockParams, make_task_lock_key, make_task_lock_params, make_task_lock_params_for_run class TestRedisClient(unittest.TestCase): @staticmethod def _get_randint(host, port): return random.randint(0, 100000) def test_redis_client_is_singleton(self): with patch('redis.Redis') as mock: mock.side_effect = self._get_randint redis_client_0_0 = RedisClient(host='host_0', port=123) redis_client_1 = RedisClient(host='host_1', port=123) redis_client_0_1 = RedisClient(host='host_0', port=123) self.assertNotEqual(redis_client_0_0, redis_client_1) self.assertEqual(redis_client_0_0, redis_client_0_1) self.assertEqual(redis_client_0_0.get_redis_client(), redis_client_0_1.get_redis_client()) class TestMakeRedisKey(unittest.TestCase): def test_make_redis_key(self): result = make_task_lock_key(file_path='gs://test_ll/dir/fname.pkl', unique_id='12345') self.assertEqual(result, 'fname_12345') class TestMakeRedisParams(unittest.TestCase): def test_make_task_lock_params_with_valid_host(self): result = make_task_lock_params( file_path='gs://aaa.pkl', unique_id='123', redis_host='0.0.0.0', 
redis_port=12345, redis_timeout=180, raise_task_lock_exception_on_collision=False ) expected = TaskLockParams( redis_host='0.0.0.0', redis_port=12345, redis_key='aaa_123', should_task_lock=True, redis_timeout=180, raise_task_lock_exception_on_collision=False, lock_extend_seconds=10, ) self.assertEqual(result, expected) def test_make_task_lock_params_with_no_host(self): result = make_task_lock_params( file_path='gs://aaa.pkl', unique_id='123', redis_host=None, redis_port=12345, redis_timeout=180, raise_task_lock_exception_on_collision=False ) expected = TaskLockParams( redis_host=None, redis_port=12345, redis_key='aaa_123', should_task_lock=False, redis_timeout=180, raise_task_lock_exception_on_collision=False, lock_extend_seconds=10, ) self.assertEqual(result, expected) def test_assert_when_redis_timeout_is_too_short(self): with self.assertRaises(AssertionError): make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, redis_timeout=2, ) class TestMakeTaskLockParamsForRun(unittest.TestCase): def test_make_task_lock_params_for_run(self): class _SampleDummyTask(gokart.TaskOnKart[Any]): pass task_self = _SampleDummyTask( redis_host='0.0.0.0', redis_port=12345, redis_timeout=180, ) result = make_task_lock_params_for_run(task_self=task_self, lock_extend_seconds=10) expected = TaskLockParams( redis_host='0.0.0.0', redis_port=12345, redis_timeout=180, redis_key='_SampleDummyTask_7e857f231830ca0fd6cf829d99f43961-run', should_task_lock=True, raise_task_lock_exception_on_collision=True, lock_extend_seconds=10, ) self.assertEqual(result, expected) ================================================ FILE: test/conflict_prevention_lock/test_task_lock_wrappers.py ================================================ import time import unittest from unittest.mock import MagicMock, patch import fakeredis from gokart.conflict_prevention_lock.task_lock import make_task_lock_params from 
gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock


def _sample_func_with_error(a: int, b: str) -> None:
    """Helper that always fails, used to verify lock release on error."""
    raise Exception()


def _sample_long_func(a: int, b: str) -> dict[str, int | str]:
    """Helper that outlives the lock timeout, used to verify lock extension."""
    time.sleep(2.7)
    return dict(a=a, b=b)


class TestWrapDumpWithLock(unittest.TestCase):
    def test_no_redis(self):
        # Without a redis host the wrapper should call through untouched.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host=None,
            redis_port=None,
        )
        mock_func = MagicMock()
        wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
        mock_func.assert_called_once()
        called_args, called_kwargs = mock_func.call_args
        self.assertTupleEqual(called_args, (123,))
        self.assertDictEqual(called_kwargs, dict(b='abc'))

    def test_use_redis(self):
        # With redis configured the wrapped function still receives the original args.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            mock_func = MagicMock()
            wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))

    def test_if_func_is_skipped_when_cache_already_exists(self):
        # exist_check returning True means the output is cached: no dump should happen.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            mock_func = MagicMock()
            wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: True)(123, b='abc')
            mock_func.assert_not_called()

    def test_check_lock_extended(self):
        # The function runs longer than redis_timeout; the lock must be extended so this does not raise.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=2,
            lock_extend_seconds=1,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            wrap_dump_with_lock(func=_sample_long_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')

    def test_lock_is_removed_after_func_is_finished(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            mock_func = MagicMock()
            wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))

            # The lock key must be gone from redis once the dump finished.
            fake_redis = fakeredis.FakeStrictRedis(server=server)
            with self.assertRaises(KeyError):
                fake_redis[task_lock_params.redis_key]

    def test_lock_is_removed_after_func_is_finished_with_error(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            try:
                wrap_dump_with_lock(func=_sample_func_with_error, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
            except Exception:
                # Even on failure the lock key must have been released.
                fake_redis = fakeredis.FakeStrictRedis(server=server)
                with self.assertRaises(KeyError):
                    fake_redis[task_lock_params.redis_key]


class
TestWrapLoadWithLock(unittest.TestCase): def test_no_redis(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host=None, redis_port=None, ) mock_func = MagicMock() resulted = wrap_load_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) def test_use_redis(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() resulted = wrap_load_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) def test_check_lock_extended(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, redis_timeout=2, lock_extend_seconds=1, ) with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis resulted = wrap_load_with_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc') expected = dict(a=123, b='abc') self.assertEqual(resulted, expected) def test_lock_is_removed_after_func_is_finished(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) server = fakeredis.FakeServer() with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: 
redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) mock_func = MagicMock() resulted = wrap_load_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): fake_redis[task_lock_params.redis_key] def test_lock_is_removed_after_func_is_finished_with_error(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) server = fakeredis.FakeServer() with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) try: wrap_load_with_lock(func=_sample_func_with_error, task_lock_params=task_lock_params)(123, b='abc') except Exception: fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): fake_redis[task_lock_params.redis_key] class TestWrapRemoveWithLock(unittest.TestCase): def test_no_redis(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host=None, redis_port=None, ) mock_func = MagicMock() resulted = wrap_remove_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) def test_use_redis(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) 
with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis mock_func = MagicMock() resulted = wrap_remove_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) def test_check_lock_extended(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, redis_timeout=2, lock_extend_seconds=1, ) with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.side_effect = fakeredis.FakeRedis resulted = wrap_remove_with_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc') expected = dict(a=123, b='abc') self.assertEqual(resulted, expected) def test_lock_is_removed_after_func_is_finished(self): task_lock_params = make_task_lock_params( file_path='test_dir/test_file.pkl', unique_id='123abc', redis_host='0.0.0.0', redis_port=12345, ) server = fakeredis.FakeServer() with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock: redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port) mock_func = MagicMock() resulted = wrap_remove_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc') mock_func.assert_called_once() called_args, called_kwargs = mock_func.call_args self.assertTupleEqual(called_args, (123,)) self.assertDictEqual(called_kwargs, dict(b='abc')) self.assertEqual(resulted, mock_func()) fake_redis = fakeredis.FakeStrictRedis(server=server) with self.assertRaises(KeyError): fake_redis[task_lock_params.redis_key] def test_lock_is_removed_after_func_is_finished_with_error(self): task_lock_params = make_task_lock_params( 
class TestPickleFileProcessor(unittest.TestCase):
    """Round-trip tests for PickleFileProcessor over local and S3-backed targets."""

    def test_dump_and_load_normal_obj(self):
        """A plain string survives a dump/load round-trip through a local target."""
        payload = 'abc'
        processor = PickleFileProcessor()
        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.pkl', format=processor.format())
            with target.open('w') as sink:
                processor.dump(payload, sink)
            with target.open('r') as source:
                restored = processor.load(source)
        self.assertEqual(restored, payload)

    def test_dump_and_load_class(self):
        """An instance whose bound method was re-wrapped at __init__ time still pickles."""
        import functools

        def plus1(func: Callable[..., int]) -> Callable[..., int]:
            # Decorator that bumps the wrapped function's result by one.
            @functools.wraps(func)
            def wrapped() -> int:
                return func() + 1

            return wrapped

        class A:
            def __init__(self) -> None:
                # Re-bind `run` so the pickled object carries a wrapped bound method.
                self.run = plus1(self.run)  # type: ignore

            def run(self) -> int:  # type: ignore
                return 1

        original = A()
        processor = PickleFileProcessor()
        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.pkl', format=processor.format())
            with target.open('w') as sink:
                processor.dump(original, sink)
            with target.open('r') as source:
                restored = processor.load(source)
        self.assertEqual(restored.run(), original.run())

    @mock_aws
    def test_dump_and_load_with_readables3file(self):
        """Dump/load works through a (mocked) S3 object-storage target."""
        boto3.resource('s3', region_name='us-east-1').create_bucket(Bucket='test')
        payload = 'abc'
        processor = PickleFileProcessor()
        target = ObjectStorage.get_object_storage_target(os.path.join('s3://test/', 'test.pkl'), processor.format())
        with target.open('w') as sink:
            processor.dump(payload, sink)
        with target.open('r') as source:
            restored = processor.load(source)
        self.assertEqual(restored, payload)
class TestCsvFileProcessor(unittest.TestCase):
    """Encoding round-trip tests for CsvFileProcessor (utf-8 and cp932)."""

    def test_dump_csv_with_utf8(self):
        """Dumping with the default encoding writes utf-8 bytes on disk."""
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        processor = CsvFileProcessor()
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.csv'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(frame, sink)
            # Reading back with utf-8 proves the dump used utf-8.
            restored = pd.read_csv(path, encoding='utf-8')
            pd.testing.assert_frame_equal(frame, restored)

    def test_dump_csv_with_cp932(self):
        """Dumping with encoding='cp932' writes cp932 bytes on disk."""
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        processor = CsvFileProcessor(encoding='cp932')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.csv'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(frame, sink)
            # Reading back with cp932 proves the dump used cp932.
            restored = pd.read_csv(path, encoding='cp932')
            pd.testing.assert_frame_equal(frame, restored)

    def test_load_csv_with_utf8(self):
        """Loading a utf-8 CSV written by pandas reproduces the frame."""
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        processor = CsvFileProcessor()
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.csv'
            frame.to_csv(path, encoding='utf-8', index=False)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                # Loading with the default (utf-8) processor must round-trip.
                restored = processor.load(source)
            pd.testing.assert_frame_equal(frame, restored)

    def test_load_csv_with_cp932(self):
        """Loading a cp932 CSV written by pandas reproduces the frame."""
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        processor = CsvFileProcessor(encoding='cp932')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.csv'
            frame.to_csv(path, encoding='cp932', index=False)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                # Loading with a cp932-configured processor must round-trip.
                restored = processor.load(source)
            pd.testing.assert_frame_equal(frame, restored)
class TestFeatherFileProcessor(unittest.TestCase):
    """Round-trip and index-name handling tests for FeatherFileProcessor."""

    def test_feather_should_return_same_dataframe(self):
        """A trivial frame survives a dump/load round-trip unchanged."""
        frame = pd.DataFrame({'a': [1]})
        processor = FeatherFileProcessor(store_index_in_feather=True)
        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.feather', format=processor.format())
            with target.open('w') as sink:
                processor.dump(frame, sink)
            with target.open('r') as source:
                restored = processor.load(source)
            pd.testing.assert_frame_equal(frame, restored)

    def test_feather_should_save_index_name(self):
        """A named index is stored in the feather file and restored on load."""
        frame = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name'))
        processor = FeatherFileProcessor(store_index_in_feather=True)
        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.feather', format=processor.format())
            with target.open('w') as sink:
                processor.dump(frame, sink)
            with target.open('r') as source:
                restored = processor.load(source)
            pd.testing.assert_frame_equal(frame, restored)

    def test_feather_should_raise_error_index_name_is_None(self):
        """The literal index name 'None' (a string) must be rejected at dump time.

        NOTE(review): presumably the processor reserves the string 'None' as a
        sentinel for an unnamed index — confirm against FeatherFileProcessor.
        """
        frame = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None'))
        processor = FeatherFileProcessor(store_index_in_feather=True)
        with tempfile.TemporaryDirectory() as workdir:
            target = LocalTarget(path=f'{workdir}/temp.feather', format=processor.format())
            with target.open('w') as sink:
                with self.assertRaises(AssertionError):
                    processor.dump(frame, sink)
pl.DataFrame) assert loaded_df.equals(df) def test_dump_and_load_polars_roundtrip(self): """Test roundtrip dump and load with polars""" df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) processor = CsvFileProcessor(dataframe_type='polars') with tempfile.TemporaryDirectory() as temp_dir: temp_path = f'{temp_dir}/temp.csv' local_target = LocalTarget(path=temp_path, format=processor.format()) with local_target.open('w') as f: processor.dump(df, f) with local_target.open('r') as f: loaded_df = processor.load(f) assert isinstance(loaded_df, pl.DataFrame) assert loaded_df.equals(df) def test_dump_polars_with_pandas_load(self): """Test that polars dump can be loaded by pandas processor""" df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) processor_polars = CsvFileProcessor(dataframe_type='polars') processor_pandas = CsvFileProcessor(dataframe_type='pandas') with tempfile.TemporaryDirectory() as temp_dir: temp_path = f'{temp_dir}/temp.csv' # Dump with polars local_target = LocalTarget(path=temp_path, format=processor_polars.format()) with local_target.open('w') as f: processor_polars.dump(df_polars, f) # Load with pandas with local_target.open('r') as f: loaded_df = processor_pandas.load(f) assert isinstance(loaded_df, pd.DataFrame) # Compare values df_polars.equals(pl.from_pandas(loaded_df)) def test_polars_with_different_separator(self): """Test polars with TSV (tab-separated values)""" df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) processor = CsvFileProcessor(sep='\t', dataframe_type='polars') with tempfile.TemporaryDirectory() as temp_dir: temp_path = f'{temp_dir}/temp.tsv' local_target = LocalTarget(path=temp_path, format=processor.format()) with local_target.open('w') as f: processor.dump(df, f) with local_target.open('r') as f: loaded_df = processor.load(f) assert isinstance(loaded_df, pl.DataFrame) assert loaded_df.equals(df) def test_error_when_polars_not_available_for_load(self): """Test error message when polars is requested but a polars 
operation fails""" # This test is a bit tricky since polars IS installed in this test class # We'll just verify the processor accepts the parameter processor = CsvFileProcessor(dataframe_type='polars') assert processor._dataframe_type == 'polars' @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') class TestJsonFileProcessorWithPolars: """Tests for JsonFileProcessor with polars support""" def test_dump_polars_dataframe(self): """Test dumping a polars DataFrame to JSON""" df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) processor = JsonFileProcessor(orient=None, dataframe_type='polars') with tempfile.TemporaryDirectory() as temp_dir: temp_path = f'{temp_dir}/temp.json' local_target = LocalTarget(path=temp_path, format=processor.format()) with local_target.open('w') as f: processor.dump(df, f) # Verify file was created and can be read by polars loaded_df = pl.read_json(temp_path) assert loaded_df.equals(df) def test_load_polars_dataframe(self): """Test loading a JSON as polars DataFrame""" df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) processor = JsonFileProcessor(orient=None, dataframe_type='polars') with tempfile.TemporaryDirectory() as temp_dir: temp_path = f'{temp_dir}/temp.json' df.write_json(temp_path) local_target = LocalTarget(path=temp_path, format=processor.format()) with local_target.open('r') as f: loaded_df = processor.load(f) assert isinstance(loaded_df, pl.DataFrame) assert loaded_df.equals(df) def test_dump_and_load_polars_roundtrip(self): """Test roundtrip dump and load with polars""" df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}) processor = JsonFileProcessor(orient=None, dataframe_type='polars') with tempfile.TemporaryDirectory() as temp_dir: temp_path = f'{temp_dir}/temp.json' local_target = LocalTarget(path=temp_path, format=processor.format()) with local_target.open('w') as f: processor.dump(df, f) with local_target.open('r') as f: loaded_df = processor.load(f) assert isinstance(loaded_df, pl.DataFrame) assert 
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestParquetFileProcessorWithPolars:
    """Tests for ParquetFileProcessor with polars support"""

    def test_dump_polars_dataframe(self):
        """Test dumping a polars DataFrame to Parquet"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            # Verify file was created and can be read by polars
            loaded_df = pl.read_parquet(temp_path)
            assert loaded_df.equals(df)

    def test_load_polars_dataframe(self):
        """Test loading a Parquet as polars DataFrame"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            df.write_parquet(temp_path)
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_and_load_polars_roundtrip(self):
        """Test roundtrip dump and load with polars"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_polars_with_pandas_load(self):
        """Test that polars dump can be loaded by pandas processor"""
        df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor_polars = ParquetFileProcessor(dataframe_type='polars')
        processor_pandas = ParquetFileProcessor(dataframe_type='pandas')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            # Dump with polars
            local_target = LocalTarget(path=temp_path, format=processor_polars.format())
            with local_target.open('w') as f:
                processor_polars.dump(df_polars, f)
            # Load with pandas
            with local_target.open('r') as f:
                loaded_df = processor_pandas.load(f)
            assert isinstance(loaded_df, pd.DataFrame)
            # BUG FIX: the equality result was previously discarded (bare
            # `df_polars.equals(...)` expression), so a value mismatch could
            # never fail this test. Assert the comparison explicitly.
            assert df_polars.equals(pl.from_pandas(loaded_df))

    def test_parquet_with_compression(self):
        """Test polars with parquet compression"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(compression='gzip', dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestLazyFrameSupport:
    """Tests for LazyFrame support in file processors using dataframe_type='polars-lazy'"""

    def test_csv_load_lazy(self):
        """Loading CSV with polars-lazy yields a LazyFrame carrying the original rows."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = CsvFileProcessor(dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.csv'
            frame.write_csv(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                result = processor.load(source)
            assert isinstance(result, pl.LazyFrame)
            assert result.collect().equals(frame)

    def test_csv_dump_lazyframe(self):
        """Dumping a LazyFrame to CSV materializes it on disk."""
        lazy = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}).lazy()
        processor = CsvFileProcessor(dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.csv'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(lazy, sink)
            # The written file must contain the collected rows.
            assert pl.read_csv(path).equals(lazy.collect())

    def test_parquet_load_lazy(self):
        """Loading Parquet with polars-lazy yields a LazyFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.parquet'
            frame.write_parquet(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                result = processor.load(source)
            assert isinstance(result, pl.LazyFrame)
            assert result.collect().equals(frame)

    def test_parquet_dump_lazyframe(self):
        """Dumping a LazyFrame to Parquet materializes it on disk."""
        lazy = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}).lazy()
        processor = ParquetFileProcessor(dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.parquet'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(lazy, sink)
            assert pl.read_parquet(path).equals(lazy.collect())

    def test_feather_load_lazy(self):
        """Loading Feather (IPC) with polars-lazy yields a LazyFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.feather'
            frame.write_ipc(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                result = processor.load(source)
            assert isinstance(result, pl.LazyFrame)
            assert result.collect().equals(frame)

    def test_feather_dump_lazyframe(self):
        """Dumping a LazyFrame to Feather (IPC) materializes it on disk."""
        lazy = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}).lazy()
        processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.feather'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(lazy, sink)
            assert pl.read_ipc(path).equals(lazy.collect())

    def test_json_load_lazy_ndjson(self):
        """Loading NDJSON (orient='records') with polars-lazy yields a LazyFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = JsonFileProcessor(orient='records', dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.ndjson'
            frame.write_ndjson(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                result = processor.load(source)
            assert isinstance(result, pl.LazyFrame)
            assert result.collect().equals(frame)

    def test_json_dump_lazyframe_ndjson(self):
        """Dumping a LazyFrame to NDJSON materializes it on disk."""
        lazy = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}).lazy()
        processor = JsonFileProcessor(orient='records', dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.ndjson'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(lazy, sink)
            assert pl.read_ndjson(path).equals(lazy.collect())

    def test_json_load_lazy_standard(self):
        """Loading standard JSON (orient=None) with polars-lazy yields a LazyFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = JsonFileProcessor(orient=None, dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.json'
            frame.write_json(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                result = processor.load(source)
            assert isinstance(result, pl.LazyFrame)
            assert result.collect().equals(frame)

    def test_json_dump_lazyframe_standard(self):
        """Dumping a LazyFrame to standard JSON (orient=None) materializes it on disk."""
        lazy = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]}).lazy()
        processor = JsonFileProcessor(orient=None, dataframe_type='polars-lazy')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.json'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as sink:
                processor.dump(lazy, sink)
            assert pl.read_json(path).equals(lazy.collect())

    def test_polars_returns_dataframe(self):
        """dataframe_type='polars' (eager) must return DataFrame, not LazyFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as workdir:
            path = f'{workdir}/temp.parquet'
            frame.write_parquet(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as source:
                result = processor.load(source)
            assert isinstance(result, pl.DataFrame)
            assert result.equals(frame)
@pytest.fixture(autouse=True) def clear_repo(self) -> None: InMemoryCacheRepository().clear() def test_dump_and_load_data(self, target: InMemoryTarget) -> None: dumped = 'dummy_data' target.dump(dumped) loaded = target.load() assert loaded == dumped def test_exist(self, target: InMemoryTarget) -> None: assert not target.exists() target.dump('dummy_data') assert target.exists() def test_last_modified_time(self, target: InMemoryTarget) -> None: input = 'dummy_data' target.dump(input) time = target.last_modification_time() assert isinstance(time, datetime) sleep(0.1) another_input = 'another_data' target.dump(another_input) another_time = target.last_modification_time() assert time < another_time target.remove() with pytest.raises(ValueError): assert target.last_modification_time() ================================================ FILE: test/in_memory/test_repository.py ================================================ import time import pytest from gokart.in_memory import InMemoryCacheRepository as Repo dummy_num = 100 class TestInMemoryCacheRepository: @pytest.fixture def repo(self) -> Repo: repo = Repo() repo.clear() return repo def test_set(self, repo: Repo) -> None: repo.set_value('dummy_key', dummy_num) assert repo.size == 1 for key, value in repo.get_gen(): assert (key, value) == ('dummy_key', dummy_num) repo.set_value('another_key', 'another_value') assert repo.size == 2 def test_get(self, repo: Repo) -> None: repo.set_value('dummy_key', dummy_num) repo.set_value('another_key', 'another_value') """Raise Error when key doesn't exist.""" with pytest.raises(KeyError): repo.get_value('not_exist_key') assert repo.get_value('dummy_key') == dummy_num assert repo.get_value('another_key') == 'another_value' def test_empty(self, repo: Repo) -> None: assert repo.empty() repo.set_value('dummmy_key', dummy_num) assert not repo.empty() def test_has(self, repo: Repo) -> None: assert not repo.has('dummy_key') repo.set_value('dummy_key', dummy_num) assert repo.has('dummy_key') 
assert not repo.has('not_exist_key') def test_remove(self, repo: Repo) -> None: repo.set_value('dummy_key', dummy_num) with pytest.raises(AssertionError): repo.remove('not_exist_key') repo.remove('dummy_key') assert not repo.has('dummy_key') def test_last_modification_time(self, repo: Repo) -> None: repo.set_value('dummy_key', dummy_num) date1 = repo.get_last_modification_time('dummy_key') time.sleep(0.1) repo.set_value('dummy_key', dummy_num) date2 = repo.get_last_modification_time('dummy_key') assert date1 < date2 ================================================ FILE: test/slack/__init__.py ================================================ ================================================ FILE: test/slack/test_slack_api.py ================================================ import unittest from logging import getLogger from unittest import mock from unittest.mock import MagicMock from slack_sdk import WebClient from slack_sdk.web.slack_response import SlackResponse from testfixtures import LogCapture import gokart.slack logger = getLogger(__name__) def _slack_response(token, data): return SlackResponse( client=WebClient(token=token), http_verb='POST', api_url='http://localhost:3000/api.test', req_args={}, data=data, headers={}, status_code=200 ) class TestSlackAPI(unittest.TestCase): @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') def test_initialization_with_invalid_token(self, patch): def _conversations_list(params={}): return _slack_response(token='invalid', data={'ok': False, 'error': 'error_reason'}) mock_client = MagicMock() mock_client.conversations_list = MagicMock(side_effect=_conversations_list) patch.return_value = mock_client with LogCapture() as log: gokart.slack.SlackAPI(token='invalid', channel='test', to_user='test user') log.check(('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel test is not found in public channels.')) @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') def 
test_invalid_channel(self, patch): def _conversations_list(params={}): return _slack_response( token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}} ) mock_client = MagicMock() mock_client.conversations_list = MagicMock(side_effect=_conversations_list) patch.return_value = mock_client with LogCapture() as log: gokart.slack.SlackAPI(token='valid', channel='invalid_channel', to_user='test user') log.check( ('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel invalid_channel is not found in public channels.') ) @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') def test_send_snippet_with_invalid_token(self, patch): def _conversations_list(params={}): return _slack_response( token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}} ) def _api_call(method, data={}): assert method == 'files.upload' return {'ok': False, 'error': 'error_reason'} mock_client = MagicMock() mock_client.conversations_list = MagicMock(side_effect=_conversations_list) mock_client.api_call = MagicMock(side_effect=_api_call) patch.return_value = mock_client with LogCapture() as log: api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user') api.send_snippet(comment='test', title='title', content='content') log.check( ('gokart.slack.slack_api', 'WARNING', 'Failed to send slack notification: Error while uploading file. 
The error reason is "error_reason".') ) @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient') def test_send(self, patch): def _conversations_list(params={}): return _slack_response( token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}} ) def _api_call(method, data={}): assert method == 'files.upload' return {'ok': False, 'error': 'error_reason'} mock_client = MagicMock() mock_client.conversations_list = MagicMock(side_effect=_conversations_list) mock_client.api_call = MagicMock(side_effect=_api_call) patch.return_value = mock_client api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user') api.send_snippet(comment='test', title='title', content='content') if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_build.py ================================================ from __future__ import annotations import io import logging import os import sys import unittest from copy import copy from typing import Any if sys.version_info >= (3, 11): from typing import assert_type else: from typing_extensions import assert_type from unittest.mock import patch import luigi import luigi.mock import gokart from gokart.build import GokartBuildError, LoggerConfig, TaskDumpConfig, TaskDumpMode, TaskDumpOutputType, process_task_info from gokart.conflict_prevention_lock.task_lock import TaskLockException class _DummyTask(gokart.TaskOnKart[str]): task_namespace = __name__ param: luigi.Parameter = luigi.Parameter() def output(self): return self.make_target('./test/dummy.pkl') def run(self): self.dump(self.param) class _DummyTaskTwoOutputs(gokart.TaskOnKart[dict[str, str]]): task_namespace = __name__ param1: luigi.Parameter = luigi.Parameter() param2: luigi.Parameter = luigi.Parameter() def output(self): return {'out1': self.make_target('./test/dummy1.pkl'), 'out2': self.make_target('./test/dummy2.pkl')} def run(self): 
self.dump(self.param1, 'out1') self.dump(self.param2, 'out2') class _DummyFailedTask(gokart.TaskOnKart[Any]): task_namespace = __name__ def run(self): raise RuntimeError class _ParallelRunner(gokart.TaskOnKart[str]): def requires(self): return [_DummyTask(param=str(i)) for i in range(10)] def run(self): self.dump('done') class _LoadRequires(gokart.TaskOnKart[str]): task: gokart.TaskInstanceParameter[gokart.TaskOnKart[str]] = gokart.TaskInstanceParameter() def requires(self): return self.task def run(self): s = self.load(self.task) self.dump(s) class RunTest(unittest.TestCase): def setUp(self): luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False luigi.configuration.LuigiConfigParser._instance = None self.config_paths = copy(luigi.configuration.LuigiConfigParser._config_paths) luigi.mock.MockFileSystem().clear() os.environ.clear() def tearDown(self): luigi.configuration.LuigiConfigParser._config_paths = self.config_paths os.environ.clear() luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False def test_build(self): text = 'test' output = gokart.build(_DummyTask(param=text), reset_register=False) self.assertEqual(output, text) def test_build_parallel(self): output = gokart.build(_ParallelRunner(), reset_register=False, workers=20) self.assertEqual(output, 'done') def test_read_config(self): class _DummyTask(gokart.TaskOnKart[Any]): task_namespace = 'test_read_config' param = luigi.Parameter() def run(self): self.dump(self.param) os.environ.setdefault('test_param', 'test') config_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config', 'test_config.ini') gokart.utils.add_config(config_file_path) output = gokart.build(_DummyTask(), reset_register=False) self.assertIsInstance(output, str) self.assertEqual(output, 'test') def test_build_dict_outputs(self): param_dict = { 'out1': 'test1', 'out2': 'test2', } output = 
gokart.build(_DummyTaskTwoOutputs(param1=param_dict['out1'], param2=param_dict['out2']), reset_register=False) assert_type(output, dict[str, str]) self.assertEqual(output, param_dict) def test_failed_task(self): with self.assertRaises(GokartBuildError): gokart.build(_DummyFailedTask(), reset_register=False, log_level=logging.CRITICAL) def test_load_requires(self): text = 'test' output = gokart.build(_LoadRequires(task=_DummyTask(param=text)), reset_register=False) self.assertEqual(output, text) def test_build_with_child_task_error(self): class CheckException(Exception): pass class FailTask(gokart.TaskOnKart[Any]): def run(self): raise CheckException() t = FailTask() with self.assertRaises(GokartBuildError) as cm: gokart.build(t, reset_register=False, log_level=logging.CRITICAL) e = cm.exception self.assertEqual(len(e.raised_exceptions), 1) self.assertIsInstance(e.raised_exceptions[t.make_unique_id()][0], CheckException) class LoggerConfigTest(unittest.TestCase): def test_logger_config(self): for level, enable_expected, disable_expected in ( (logging.INFO, logging.INFO, logging.DEBUG), (logging.DEBUG, logging.DEBUG, logging.NOTSET), (logging.CRITICAL, logging.CRITICAL, logging.ERROR), ): with self.subTest(level=level, enable_expected=enable_expected, disable_expected=disable_expected): with LoggerConfig(level) as lc: self.assertTrue(lc.logger.isEnabledFor(enable_expected)) self.assertTrue(not lc.logger.isEnabledFor(disable_expected)) class ProcessTaskInfoTest(unittest.TestCase): def test_process_task_info(self): task = _DummyTask(param='test') for config in ( TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.PRINT), TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.PRINT), ): with LoggerConfig(level=logging.INFO): from gokart.build import logger log_stream = io.StringIO() handler = logging.StreamHandler(log_stream) handler.setLevel(logging.INFO) logger.addHandler(handler) process_task_info(task, config) 
logger.removeHandler(handler) handler.close() self.assertIn(member=str(task.make_unique_id()), container=log_stream.getvalue()) class _FailThreeTimesAndSuccessTask(gokart.TaskOnKart[Any]): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.failed_counter = 0 def run(self): if self.failed_counter < 3: self.failed_counter += 1 raise TaskLockException() self.dump('done') class TestBuildHasLockedTaskException(unittest.TestCase): def test_build_expo_backoff_when_luigi_failed_due_to_locked_task(self): gokart.build(_FailThreeTimesAndSuccessTask(), reset_register=False) class TestBuildFailedAndSchedulingFailed(unittest.TestCase): def test_build_raises_exception_on_failed_and_scheduling_failed(self): """Test that build() raises GokartBuildError when FAILED_AND_SCHEDULING_FAILED occurs""" # Create a mock result object with FAILED_AND_SCHEDULING_FAILED status class MockResult: def __init__(self): self.status = luigi.LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED self.summary_text = 'Task failed and scheduling failed' # Mock luigi.build to return FAILED_AND_SCHEDULING_FAILED status with patch('luigi.build') as mock_luigi_build: mock_luigi_build.return_value = MockResult() # This should now raise GokartBuildError after the fix with self.assertRaises(GokartBuildError): gokart.build(_DummyTask(param='test'), reset_register=False, log_level=logging.CRITICAL) def test_build_not_raises_exception_when_success_with_retry(self): """Test that build() does not raise GokartBuildError when task succeeds with retry""" # Create a mock result object with SUCCESS_WITH_RETRY status class MockResult: def __init__(self): self.status = luigi.LuigiStatusCode.SUCCESS_WITH_RETRY self.summary_text = 'Task completed successfully after retries' # Mock _build_task to return a test value directly with patch('luigi.build') as mock_luigi_build: mock_luigi_build.return_value = MockResult() # Create a mock task that will be used by build() mock_task = _DummyTask(param='test') # This 
should not raise GokartBuildError # The test output will be whatever the mock returns gokart.build(mock_task, reset_register=False, return_value=False, log_level=logging.CRITICAL) def test_build_not_raises_exception_on_scheduling_failed_only(self): """Test that build() raises GokartBuildError when SCHEDULING_FAILED occurs""" # Create a mock result object with SCHEDULING_FAILED status class MockResult: def __init__(self): self.status = luigi.LuigiStatusCode.SCHEDULING_FAILED self.summary_text = 'Task scheduling failed' # Mock luigi.build to return SCHEDULING_FAILED status with patch('luigi.build') as mock_luigi_build: mock_luigi_build.return_value = MockResult() # This should raise GokartBuildError after the fix with self.assertRaises(GokartBuildError): gokart.build(_DummyTask(param='test'), reset_register=False, log_level=logging.CRITICAL) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_cache_unique_id.py ================================================ import os import unittest from typing import Any import luigi import luigi.mock import gokart class _DummyTask(gokart.TaskOnKart[Any]): def requires(self): return _DummyTaskDep() def run(self): self.dump(self.load()) class _DummyTaskDep(gokart.TaskOnKart[str]): param: luigi.Parameter = luigi.Parameter() def run(self): self.dump(self.param) class CacheUniqueIDTest(unittest.TestCase): def setUp(self): luigi.configuration.LuigiConfigParser._instance = None luigi.mock.MockFileSystem().clear() os.environ.clear() @staticmethod def _set_param(cls, attr_name: str, param: luigi.Parameter) -> None: # type: ignore # Luigi 3.8.0+ uses __set_name__ to register _attribute_name on Parameter descriptors. # When assigning after class creation (bypassing the metaclass), call it manually. 
param.__set_name__(cls, attr_name) setattr(cls, attr_name, param) def test_cache_unique_id_true(self): self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='original_param')) output1 = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False) self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='updated_param')) output2 = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False) self.assertEqual(output1, output2) def test_cache_unique_id_false(self): self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='original_param')) output1 = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False) self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='updated_param')) output2 = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False) self.assertNotEqual(output1, output2) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_config_params.py ================================================ import unittest from typing import Any import luigi from luigi.cmdline_parser import CmdlineParser import gokart from gokart.config_params import inherits_config_params def in_parse(cmds, deferred_computation): """function copied from luigi: https://github.com/spotify/luigi/blob/e2228418eec60b68ca09a30c878ab26413846847/test/helpers.py""" with CmdlineParser.global_instance(cmds) as cp: deferred_computation(cp.get_task_obj()) class ConfigClass(luigi.Config): param_a = luigi.Parameter(default='config a') param_b = luigi.Parameter(default='config b') param_c = luigi.Parameter(default='config c') @inherits_config_params(ConfigClass) class Inherited(gokart.TaskOnKart[Any]): param_a = luigi.Parameter() param_b = luigi.Parameter(default='overrided') @inherits_config_params(ConfigClass, parameter_alias={'param_a': 'param_d'}) class Inherited2(gokart.TaskOnKart[Any]): param_c = luigi.Parameter() param_d = luigi.Parameter() class ChildTask(Inherited): pass 
class ChildTaskWithNewParam(Inherited): param_new = luigi.Parameter() class ConfigClass2(luigi.Config): param_a = luigi.Parameter(default='config a from config class 2') @inherits_config_params(ConfigClass2) class ChildTaskWithNewConfig(Inherited): pass class TestInheritsConfigParam(unittest.TestCase): def test_inherited_params(self): # test fill values in_parse(['Inherited'], lambda task: self.assertEqual(task.param_a, 'config a')) # test overrided in_parse(['Inherited'], lambda task: self.assertEqual(task.param_b, 'config b')) # Command line argument takes precedence over config param in_parse(['Inherited', '--param-a', 'command line arg'], lambda task: self.assertEqual(task.param_a, 'command line arg')) # Parameters which is not a member of the task will not be set with self.assertRaises(AttributeError): in_parse(['Inherited'], lambda task: task.param_c) # test parameter name alias in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_c, 'config c')) in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_d, 'config a')) def test_child_task(self): in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_a, 'config a')) in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_b, 'config b')) in_parse(['ChildTask', '--param-a', 'command line arg'], lambda task: self.assertEqual(task.param_a, 'command line arg')) with self.assertRaises(AttributeError): in_parse(['ChildTask'], lambda task: task.param_c) def test_child_override(self): in_parse(['ChildTaskWithNewConfig'], lambda task: self.assertEqual(task.param_a, 'config a from config class 2')) in_parse(['ChildTaskWithNewConfig'], lambda task: self.assertEqual(task.param_b, 'config b')) ================================================ FILE: test/test_explicit_bool_parameter.py ================================================ import unittest from typing import Any import luigi import luigi.mock from luigi.cmdline_parser import CmdlineParser import gokart def in_parse(cmds, 
deferred_computation): with CmdlineParser.global_instance(cmds) as cp: deferred_computation(cp.get_task_obj()) class WithDefaultTrue(gokart.TaskOnKart[Any]): param = gokart.ExplicitBoolParameter(default=True) class WithDefaultFalse(gokart.TaskOnKart[Any]): param = gokart.ExplicitBoolParameter(default=False) class ExplicitParsing(gokart.TaskOnKart[Any]): param = gokart.ExplicitBoolParameter() def run(self): ExplicitParsing._param = self.param # type: ignore class TestExplicitBoolParameter(unittest.TestCase): def test_bool_default(self): self.assertTrue(WithDefaultTrue().param) self.assertFalse(WithDefaultFalse().param) def test_parse_param(self): in_parse(['ExplicitParsing', '--param', 'true'], lambda task: self.assertTrue(task.param)) in_parse(['ExplicitParsing', '--param', 'false'], lambda task: self.assertFalse(task.param)) in_parse(['ExplicitParsing', '--param', 'True'], lambda task: self.assertTrue(task.param)) in_parse(['ExplicitParsing', '--param', 'False'], lambda task: self.assertFalse(task.param)) def test_missing_parameter(self): with self.assertRaises(luigi.parameter.MissingParameterException): in_parse(['ExplicitParsing'], lambda: True) def test_value_error(self): with self.assertRaises(ValueError): in_parse(['ExplicitParsing', '--param', 'Foo'], lambda: True) def test_expected_one_argment_error(self): # argparse throw "expected one argument" error with self.assertRaises(SystemExit): in_parse(['ExplicitParsing', '--param'], lambda: True) ================================================ FILE: test/test_gcs_config.py ================================================ import os import unittest from unittest.mock import MagicMock, patch from gokart.gcs_config import GCSConfig class TestGCSConfig(unittest.TestCase): def test_get_gcs_client_without_gcs_credential_name(self): mock = MagicMock() os.environ['env_name'] = '' with patch('luigi.contrib.gcs.GCSClient', mock): GCSConfig(gcs_credential_name='env_name')._get_gcs_client() 
self.assertEqual(dict(oauth_credentials=None), mock.call_args[1]) def test_get_gcs_client_with_file_path(self): mock = MagicMock() file_path = 'test.json' os.environ['env_name'] = file_path with patch('luigi.contrib.gcs.GCSClient'): with patch('google.oauth2.service_account.Credentials.from_service_account_file', mock): with patch('os.path.isfile', return_value=True): GCSConfig(gcs_credential_name='env_name')._get_gcs_client() self.assertEqual(file_path, mock.call_args[0][0]) def test_get_gcs_client_with_json(self): mock = MagicMock() json_str = '{"test": 1}' os.environ['env_name'] = json_str with patch('luigi.contrib.gcs.GCSClient'): with patch('google.oauth2.service_account.Credentials.from_service_account_info', mock): GCSConfig(gcs_credential_name='env_name')._get_gcs_client() self.assertEqual(dict(test=1), mock.call_args[0][0]) ================================================ FILE: test/test_gcs_obj_metadata_client.py ================================================ from __future__ import annotations import datetime import unittest from typing import Any from unittest.mock import MagicMock, patch import gokart from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient from gokart.required_task_output import RequiredTaskOutput from gokart.target import TargetOnKart class _DummyTaskOnKart(gokart.TaskOnKart[str]): task_namespace = __name__ def run(self): self.dump('Dummy TaskOnKart') class TestGCSObjectMetadataClient(unittest.TestCase): def setUp(self): self.task_params: dict[str, str] = { 'param1': 'a' * 1000, 'param2': str(1000), 'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}), 'param4': str([1, 2, 3, 4, 5]), 'param5': str(datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)), 'param6': '', } self.custom_labels: dict[str, Any] = { 'created_at': datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5), 'created_by': 'hoge fuga', 'empty': True, 'try_num': 3, } self.task_params_with_conflicts = { 'empty': 
'False', 'created_by': 'fuga hoge', 'param1': 'a' * 10, } def test_normalize_labels_not_empty(self): got = GCSObjectMetadataClient._normalize_labels(None) self.assertEqual(got, {}) def test_normalize_labels_has_value(self): got = GCSObjectMetadataClient._normalize_labels(self.task_params) self.assertIsInstance(got, dict) self.assertIsInstance(got, dict) self.assertIn('param1', got) self.assertIn('param2', got) self.assertIn('param3', got) self.assertIn('param4', got) self.assertIn('param5', got) self.assertIn('param6', got) def test_get_patched_obj_metadata_only_task_params(self): got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=None) self.assertIsInstance(got, dict) self.assertIn('param1', got) self.assertIn('param2', got) self.assertIn('param3', got) self.assertIn('param4', got) self.assertIn('param5', got) self.assertNotIn('param6', got) def test_get_patched_obj_metadata_only_custom_labels(self): got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=None, custom_labels=self.custom_labels) self.assertIsInstance(got, dict) self.assertIn('created_at', got) self.assertIn('created_by', got) self.assertIn('empty', got) self.assertIn('try_num', got) def test_get_patched_obj_metadata_with_both_task_params_and_custom_labels(self): got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=self.custom_labels) self.assertIsInstance(got, dict) self.assertIn('param1', got) self.assertIn('param2', got) self.assertIn('param3', got) self.assertIn('param4', got) self.assertIn('param5', got) self.assertNotIn('param6', got) self.assertIn('created_at', got) self.assertIn('created_by', got) self.assertIn('empty', got) self.assertIn('try_num', got) def test_get_patched_obj_metadata_with_exceeded_size_metadata(self): size_exceeded_task_params = { 'param1': 'a' * 5000, 'param2': 'b' * 5000, } want = { 'param1': 'a' * 5000, } got = 
GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=size_exceeded_task_params) self.assertEqual(got, want) def test_get_patched_obj_metadata_with_conflicts(self): got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params_with_conflicts, custom_labels=self.custom_labels) self.assertIsInstance(got, dict) self.assertIn('created_at', got) self.assertIn('created_by', got) self.assertIn('empty', got) self.assertIn('try_num', got) self.assertEqual(got['empty'], 'True') self.assertEqual(got['created_by'], 'hoge fuga') self.assertEqual(got['param1'], 'a' * 10) def test_get_patched_obj_metadata_with_required_task_outputs(self): got = GCSObjectMetadataClient._get_patched_obj_metadata( {}, required_task_outputs=[ RequiredTaskOutput(task_name='task1', output_path='path/to/output1'), ], ) self.assertIsInstance(got, dict) self.assertIn('__required_task_outputs', got) self.assertEqual(got['__required_task_outputs'], '[{"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}]') def test_get_patched_obj_metadata_with_nested_required_task_outputs(self): got = GCSObjectMetadataClient._get_patched_obj_metadata( {}, required_task_outputs={ 'nested_task': {'nest': RequiredTaskOutput(task_name='task1', output_path='path/to/output1')}, }, ) self.assertIsInstance(got, dict) self.assertIn('__required_task_outputs', got) self.assertEqual( got['__required_task_outputs'], '{"nested_task": {"nest": {"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}}}' ) def test_adjust_gcs_metadata_limit_size_runtime_error(self): large_labels = {} for i in range(100): large_labels[f'key_{i}'] = 'x' * 1000 GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(large_labels) class TestGokartTask(unittest.TestCase): @patch.object(_DummyTaskOnKart, '_get_output_target') def test_mock_target_on_kart(self, mock_get_output_target): mock_target = MagicMock(spec=TargetOnKart) mock_get_output_target.return_value = mock_target task = 
_DummyTaskOnKart() task.dump({'key': 'value'}, mock_target) mock_target.dump.assert_called_once_with( {'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None, required_task_outputs=[] ) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_info.py ================================================ import unittest from unittest.mock import patch import luigi import luigi.mock from luigi.mock import MockFileSystem, MockTarget import gokart import gokart.info from test.tree.test_task_info import _DoubleLoadSubTask, _SubTask, _Task class TestInfo(unittest.TestCase): def setUp(self) -> None: MockFileSystem().clear() luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False def tearDown(self) -> None: luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_pending(self): task = _Task(param=1, sub=_SubTask(param=2)) # check before running tree = gokart.info.make_tree_info(task) expected = r""" └─-\(PENDING\) _Task\[[a-z0-9]*\] └─-\(PENDING\) _SubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_complete(self): task = _Task(param=1, sub=_SubTask(param=2)) # check after sub task runs gokart.build(task, reset_register=False) tree = gokart.info.make_tree_info(task) expected = r""" └─-\(COMPLETE\) _Task\[[a-z0-9]*\] └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_abbreviation(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check after sub task runs gokart.build(task, 
reset_register=False) tree = gokart.info.make_tree_info(task) expected = r""" └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\] \|--\(COMPLETE\) _Task\[[a-z0-9]*\] \| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\] └─-\(COMPLETE\) _Task\[[a-z0-9]*\] └─- \.\.\.$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_not_compress(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check after sub task runs gokart.build(task, reset_register=False) tree = gokart.info.make_tree_info(task, abbr=False) expected = r""" └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\] \|--\(COMPLETE\) _Task\[[a-z0-9]*\] \| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\] └─-\(COMPLETE\) _Task\[[a-z0-9]*\] └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_not_compress_ignore_task(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check after sub task runs gokart.build(task, reset_register=False) tree = gokart.info.make_tree_info(task, abbr=False, ignore_task_names=['_Task']) expected = r""" └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_large_data_fram_processor.py ================================================ import os import shutil import unittest import numpy as np import pandas as pd from gokart.target import LargeDataFrameProcessor from test.util import _get_temporary_directory class LargeDataFrameProcessorTest(unittest.TestCase): def setUp(self): self.temporary_directory = _get_temporary_directory() def tearDown(self): shutil.rmtree(self.temporary_directory, ignore_errors=True) def 
test_save_and_load(self): file_path = os.path.join(self.temporary_directory, 'test.zip') df = pd.DataFrame(dict(data=np.random.uniform(0, 1, size=int(1e6)))) processor = LargeDataFrameProcessor(max_byte=int(1e6)) processor.save(df, file_path) loaded = processor.load(file_path) pd.testing.assert_frame_equal(loaded, df, check_like=True) def test_save_and_load_empty(self): file_path = os.path.join(self.temporary_directory, 'test_with_empty.zip') df = pd.DataFrame() processor = LargeDataFrameProcessor(max_byte=int(1e6)) processor.save(df, file_path) loaded = processor.load(file_path) pd.testing.assert_frame_equal(loaded, df, check_like=True) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_list_task_instance_parameter.py ================================================ import unittest from typing import Any import luigi import gokart from gokart import TaskOnKart class _DummySubTask(TaskOnKart[Any]): task_namespace = __name__ pass class _DummyTask(TaskOnKart[Any]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter() task: gokart.TaskInstanceParameter[_DummySubTask] = gokart.TaskInstanceParameter(default=_DummySubTask()) class ListTaskInstanceParameterTest(unittest.TestCase): def setUp(self): _DummyTask.clear_instance_cache() def test_serialize_and_parse(self): original = [_DummyTask(param=3), _DummyTask(param=3)] s = gokart.ListTaskInstanceParameter().serialize(original) parsed = gokart.ListTaskInstanceParameter().parse(s) self.assertEqual(parsed[0].task_id, original[0].task_id) self.assertEqual(parsed[1].task_id, original[1].task_id) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_mypy.py ================================================ import tempfile import unittest from mypy import api from test.config import PYPROJECT_TOML class TestMyMypyPlugin(unittest.TestCase): def test_plugin_no_issue(self): test_code = """ import luigi 
from luigi import Parameter import gokart import datetime class MyTask(gokart.TaskOnKart): foo: int = luigi.IntParameter() # type: ignore bar: str = luigi.Parameter() # type: ignore baz: bool = gokart.ExplicitBoolParameter() qux: str = Parameter() # https://github.com/m3dev/gokart/issues/395 datetime: datetime.datetime = luigi.DateMinuteParameter(interval=10, default=datetime.datetime(2021, 1, 1)) # TaskOnKart parameters: # - `complete_check_at_run` MyTask(foo=1, bar='bar', baz=False, qux='qux', complete_check_at_run=False) """ with tempfile.NamedTemporaryFile(suffix='.py') as test_file: test_file.write(test_code.encode('utf-8')) test_file.flush() stdout, stderr, exitcode = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) self.assertEqual(exitcode, 0, f'mypy plugin error occurred:\nstdout: {stdout}\nstderr: {stderr}') self.assertIn('Success: no issues found', stdout) def test_plugin_invalid_arg(self): test_code = """ import luigi import gokart class MyTask(gokart.TaskOnKart): foo: int = luigi.IntParameter() # type: ignore bar: str = luigi.Parameter() # type: ignore baz: bool = gokart.ExplicitBoolParameter() # issue: foo is int # not issue: bar is missing, because it can be set by config file. 
# TaskOnKart parameters: # - `complete_check_at_run` MyTask(foo='1', baz='not bool', complete_check_at_run='not bool') """ with tempfile.NamedTemporaryFile(suffix='.py') as test_file: test_file.write(test_code.encode('utf-8')) test_file.flush() stdout, stderr, exitcode = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) self.assertEqual(exitcode, 1, f'mypy plugin error not occurred:\nstdout: {stdout}\nstderr: {stderr}') self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', stdout) self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', stdout) self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', stdout) self.assertIn('Found 3 errors in 1 file (checked 1 source file)', stdout) ================================================ FILE: test/test_pandas_type_check_framework.py ================================================ from __future__ import annotations import logging import unittest from logging import getLogger from typing import Any from unittest.mock import patch import luigi import pandas as pd from luigi.mock import MockFileSystem, MockTarget import gokart from gokart.build import GokartBuildError from gokart.pandas_type_config import PandasTypeConfig logger = getLogger(__name__) class TestPandasTypeConfig(PandasTypeConfig): task_namespace = 'test_pandas_type_check_framework' @classmethod def type_dict(cls) -> dict[str, Any]: return {'system_cd': int} class _DummyFailTask(gokart.TaskOnKart[pd.DataFrame]): task_namespace = 'test_pandas_type_check_framework' rerun: luigi.BoolParameter = luigi.BoolParameter(default=True, significant=False) def output(self): return self.make_target('dummy.pkl') def run(self): df = pd.DataFrame(dict(system_cd=['1'])) self.dump(df) class _DummyFailWithNoneTask(gokart.TaskOnKart[pd.DataFrame]): 
task_namespace = 'test_pandas_type_check_framework' rerun: luigi.BoolParameter = luigi.BoolParameter(default=True, significant=False) def output(self): return self.make_target('dummy.pkl') def run(self): df = pd.DataFrame(dict(system_cd=[1, None])) self.dump(df) class _DummySuccessTask(gokart.TaskOnKart[pd.DataFrame]): task_namespace = 'test_pandas_type_check_framework' rerun: luigi.BoolParameter = luigi.BoolParameter(default=True, significant=False) def output(self): return self.make_target('dummy.pkl') def run(self): df = pd.DataFrame(dict(system_cd=[1])) self.dump(df) class TestPandasTypeCheckFramework(unittest.TestCase): def setUp(self) -> None: luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False MockFileSystem().clear() # same way as luigi https://github.com/spotify/luigi/blob/fe7ecf4acf7cf4c084bd0f32162c8e0721567630/test/helpers.py#L175 self._stashed_reg = luigi.task_register.Register._get_reg() def tearDown(self) -> None: luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False luigi.task_register.Register._set_reg(self._stashed_reg) @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock']) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_fail_with_gokart_run(self): with self.assertRaises(SystemExit) as exit_code: gokart.run() self.assertNotEqual(exit_code.exception.code, 0) # raise Error def test_fail(self): with self.assertRaises(GokartBuildError): gokart.build(_DummyFailTask(), log_level=logging.CRITICAL) def test_fail_with_None(self): with self.assertRaises(GokartBuildError): gokart.build(_DummyFailWithNoneTask(), log_level=logging.CRITICAL) def test_success(self): gokart.build(_DummySuccessTask()) # no error ================================================ FILE: test/test_pandas_type_config.py 
================================================ from __future__ import annotations from datetime import date, datetime from typing import Any from unittest import TestCase import numpy as np import pandas as pd from gokart import PandasTypeConfig from gokart.pandas_type_config import PandasTypeError class _DummyPandasTypeConfig(PandasTypeConfig): @classmethod def type_dict(cls) -> dict[str, Any]: return {'int_column': int, 'datetime_column': datetime, 'array_column': np.ndarray} class TestPandasTypeConfig(TestCase): def test_int_fail(self): df = pd.DataFrame(dict(int_column=['1'])) with self.assertRaises(PandasTypeError): _DummyPandasTypeConfig().check(df) def test_int_success(self): df = pd.DataFrame(dict(int_column=[1])) _DummyPandasTypeConfig().check(df) def test_datetime_fail(self): df = pd.DataFrame(dict(datetime_column=[date(2019, 1, 12)])) with self.assertRaises(PandasTypeError): _DummyPandasTypeConfig().check(df) def test_datetime_success(self): df = pd.DataFrame(dict(datetime_column=[datetime(2019, 1, 12, 0, 0, 0)])) _DummyPandasTypeConfig().check(df) def test_array_fail(self): df = pd.DataFrame(dict(array_column=[[1, 2]])) with self.assertRaises(PandasTypeError): _DummyPandasTypeConfig().check(df) def test_array_success(self): df = pd.DataFrame(dict(array_column=[np.array([1, 2])])) _DummyPandasTypeConfig().check(df) ================================================ FILE: test/test_restore_task_by_id.py ================================================ import unittest from typing import Any from unittest.mock import patch import luigi import luigi.mock import gokart class _SubDummyTask(gokart.TaskOnKart[str]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter() def run(self): self.dump('test') class _DummyTask(gokart.TaskOnKart[str]): task_namespace = __name__ sub_task: gokart.TaskInstanceParameter[gokart.TaskOnKart[Any]] = gokart.TaskInstanceParameter() def output(self): return self.make_target('test.txt') def run(self): 
self.dump('test') class RestoreTaskByIDTest(unittest.TestCase): def setUp(self) -> None: luigi.mock.MockFileSystem().clear() @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs)) def test(self): task = _DummyTask(sub_task=_SubDummyTask(param=10)) luigi.build([task], local_scheduler=True, log_level='CRITICAL') unique_id = task.make_unique_id() restored = _DummyTask.restore(unique_id) self.assertTrue(task.make_unique_id(), restored.make_unique_id()) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_run.py ================================================ import os import unittest from typing import Any from unittest.mock import MagicMock, patch import luigi import luigi.mock import gokart from gokart.run import _try_to_send_event_summary_to_slack class _DummyTask(gokart.TaskOnKart[Any]): task_namespace = __name__ param: luigi.StrParameter = luigi.StrParameter() class RunTest(unittest.TestCase): def setUp(self): luigi.configuration.LuigiConfigParser._instance = None luigi.mock.MockFileSystem().clear() os.environ.clear() @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler']) def test_run(self): config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini') luigi.configuration.LuigiConfigParser.add_config_path(config_file_path) os.environ.setdefault('test_param', 'test') with self.assertRaises(SystemExit) as exit_code: gokart.run() self.assertEqual(exit_code.exception.code, 0) @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--log-level=CRITICAL', '--local-scheduler']) def test_run_with_undefined_environ(self): config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini') luigi.configuration.LuigiConfigParser.add_config_path(config_file_path) with self.assertRaises(luigi.parameter.MissingParameterException): gokart.run() @patch( 'sys.argv', new=[ 
'main', '--tree-info-mode=simple', '--tree-info-output-path=tree.txt', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler', ], ) @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs)) def test_run_tree_info(self): config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini') luigi.configuration.LuigiConfigParser.add_config_path(config_file_path) os.environ.setdefault('test_param', 'test') tree_info = gokart.tree_info(mode='simple', output_path='tree.txt') with self.assertRaises(SystemExit): gokart.run() self.assertTrue(gokart.make_tree_info(_DummyTask(param='test')), tree_info.output().load()) @patch('gokart.make_tree_info') def test_try_to_send_event_summary_to_slack(self, make_tree_info_mock: MagicMock) -> None: event_aggregator_mock = MagicMock() event_aggregator_mock.get_summury.return_value = f'{__name__}._DummyTask' event_aggregator_mock.get_event_list.return_value = f'{__name__}._DummyTask:[]' make_tree_info_mock.return_value = 'tree' def get_content(content: str, **kwargs: Any) -> None: self.output = content slack_api_mock = MagicMock() slack_api_mock.send_snippet.side_effect = get_content cmdline_args = [f'{__name__}._DummyTask', '--param', 'test'] with patch('gokart.slack.SlackConfig.send_tree_info', True): _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args) expects = os.linesep.join(['===== Event List ====', event_aggregator_mock.get_event_list(), os.linesep, '==== Tree Info ====', 'tree']) results = self.output self.assertEqual(expects, results) cmdline_args = [f'{__name__}._DummyTask', '--param', 'test'] with patch('gokart.slack.SlackConfig.send_tree_info', False): _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args) expects = os.linesep.join( [ '===== Event List ====', event_aggregator_mock.get_event_list(), os.linesep, '==== Tree Info ====', 'Please add 
SlackConfig.send_tree_info to include tree-info', ] ) results = self.output self.assertEqual(expects, results) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_s3_config.py ================================================ import unittest from gokart.s3_config import S3Config class TestS3Config(unittest.TestCase): def test_get_same_s3_client(self): client_a = S3Config().get_s3_client() client_b = S3Config().get_s3_client() self.assertEqual(client_a, client_b) ================================================ FILE: test/test_s3_zip_client.py ================================================ import os import shutil import unittest import boto3 from moto import mock_aws from gokart.s3_zip_client import S3ZipClient from test.util import _get_temporary_directory class TestS3ZipClient(unittest.TestCase): def setUp(self): self.temporary_directory = _get_temporary_directory() def tearDown(self): shutil.rmtree(self.temporary_directory, ignore_errors=True) # remove temporary zip archive if exists. if os.path.exists(f'{self.temporary_directory}.zip'): os.remove(f'{self.temporary_directory}.zip') @mock_aws def test_make_archive(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') file_path = os.path.join('s3://test/', 'test.zip') temporary_directory = self.temporary_directory zip_client = S3ZipClient(file_path=file_path, temporary_directory=temporary_directory) # raise error if temporary directory does not exist. with self.assertRaises(FileNotFoundError): zip_client.make_archive() # run without error because temporary directory exists. 
os.makedirs(temporary_directory, exist_ok=True) zip_client.make_archive() @mock_aws def test_unpack_archive(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') file_path = os.path.join('s3://test/', 'test.zip') in_temporary_directory = os.path.join(self.temporary_directory, 'in', 'dummy') out_temporary_directory = os.path.join(self.temporary_directory, 'out', 'dummy') # make dummy zip file. os.makedirs(in_temporary_directory, exist_ok=True) in_zip_client = S3ZipClient(file_path=file_path, temporary_directory=in_temporary_directory) in_zip_client.make_archive() # load dummy zip file. out_zip_client = S3ZipClient(file_path=file_path, temporary_directory=out_temporary_directory) self.assertFalse(os.path.exists(out_temporary_directory)) out_zip_client.unpack_archive() ================================================ FILE: test/test_serializable_parameter.py ================================================ import json import tempfile from dataclasses import asdict, dataclass from typing import Any import luigi import pytest from luigi.cmdline_parser import CmdlineParser from mypy import api from gokart import SerializableParameter, TaskOnKart from test.config import PYPROJECT_TOML @dataclass(frozen=True) class Config: foo: int bar: str def gokart_serialize(self) -> str: # dict is ordered in Python 3.7+ return json.dumps(asdict(self)) @classmethod def gokart_deserialize(cls, s: str) -> 'Config': return cls(**json.loads(s)) class SerializableParameterWithOutDefault(TaskOnKart[Any]): task_namespace = __name__ config: SerializableParameter[Config] = SerializableParameter(object_type=Config) def run(self): self.dump(self.config) class SerializableParameterWithDefault(TaskOnKart[Any]): task_namespace = __name__ config: SerializableParameter[Config] = SerializableParameter(object_type=Config, default=Config(foo=1, bar='bar')) def run(self): self.dump(self.config) class TestSerializableParameter: def test_default(self): with 
CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithDefault']) as cp: assert cp.get_task_obj().config == Config(foo=1, bar='bar') def test_parse_param(self): with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp: assert cp.get_task_obj().config == Config(foo=100, bar='val') def test_missing_parameter(self): with pytest.raises(luigi.parameter.MissingParameterException): with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault']) as cp: cp.get_task_obj() def test_value_error(self): with pytest.raises(ValueError): with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']) as cp: cp.get_task_obj() def test_expected_one_argument_error(self): with pytest.raises(SystemExit): with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config']) as cp: cp.get_task_obj() def test_mypy(self): """check invalid object cannot used for SerializableParameter""" test_code = """ import gokart class InvalidClass: ... 
gokart.SerializableParameter(object_type=InvalidClass) """ with tempfile.NamedTemporaryFile(suffix='.py') as test_file: test_file.write(test_code.encode('utf-8')) test_file.flush() result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name]) assert 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass" [type-var]' in result[0] ================================================ FILE: test/test_target.py ================================================ import io import os import shutil import unittest from datetime import datetime from unittest.mock import patch import boto3 import numpy as np import pandas as pd from matplotlib import pyplot from moto import mock_aws from gokart.file_processor.base import _ChunkedLargeFileReader from gokart.target import make_model_target, make_target from test.util import _get_temporary_directory class LocalTargetTest(unittest.TestCase): def setUp(self): self.temporary_directory = _get_temporary_directory() def tearDown(self): shutil.rmtree(self.temporary_directory, ignore_errors=True) def test_save_and_load_pickle_file(self): obj = 1 file_path = os.path.join(self.temporary_directory, 'test.pkl') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) with unittest.mock.patch('gokart.file_processor.base._ChunkedLargeFileReader', wraps=_ChunkedLargeFileReader) as monkey: loaded = target.load() monkey.assert_called() self.assertEqual(loaded, obj) def test_save_and_load_text_file(self): obj = 1 file_path = os.path.join(self.temporary_directory, 'test.txt') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() self.assertEqual(loaded, [str(obj)], msg='should save an object as List[str].') def test_save_and_load_gzip(self): obj = 1 file_path = os.path.join(self.temporary_directory, 'test.gz') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() 
self.assertEqual(loaded, [str(obj)], msg='should save an object as List[str].') def test_save_and_load_npz(self): obj = np.ones(shape=10, dtype=np.float32) file_path = os.path.join(self.temporary_directory, 'test.npz') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() np.testing.assert_almost_equal(obj, loaded) def test_save_and_load_figure(self): figure_binary = io.BytesIO() pd.DataFrame(dict(x=range(10), y=range(10))).plot.scatter(x='x', y='y') pyplot.savefig(figure_binary) figure_binary.seek(0) file_path = os.path.join(self.temporary_directory, 'test.png') target = make_target(file_path=file_path, unique_id=None) target.dump(figure_binary.read()) loaded = target.load() self.assertGreater(len(loaded), 1000) # any binary def test_save_and_load_csv(self): obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) file_path = os.path.join(self.temporary_directory, 'test.csv') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) def test_save_and_load_tsv(self): obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) file_path = os.path.join(self.temporary_directory, 'test.tsv') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) def test_save_and_load_parquet(self): obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) file_path = os.path.join(self.temporary_directory, 'test.parquet') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) def test_save_and_load_feather(self): obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=pd.Index([33, 44], name='object_index')) file_path = os.path.join(self.temporary_directory, 'test.feather') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) def 
test_save_and_load_feather_without_store_index_in_feather(self): obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=pd.Index([33, 44], name='object_index')).reset_index() file_path = os.path.join(self.temporary_directory, 'test.feather') target = make_target(file_path=file_path, unique_id=None, store_index_in_feather=False) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) def test_last_modified_time(self): obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) file_path = os.path.join(self.temporary_directory, 'test.csv') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) t = target.last_modification_time() self.assertIsInstance(t, datetime) def test_last_modified_time_without_file(self): file_path = os.path.join(self.temporary_directory, 'test.csv') target = make_target(file_path=file_path, unique_id=None) with self.assertRaises(FileNotFoundError): target.last_modification_time() def test_save_pandas_series(self): obj = pd.Series(data=[1, 2], name='column_name') file_path = os.path.join(self.temporary_directory, 'test.csv') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_series_equal(loaded['column_name'], obj) def test_dump_with_lock(self): with patch('gokart.target.wrap_dump_with_lock') as wrap_with_lock_mock: obj = 1 file_path = os.path.join(self.temporary_directory, 'test.pkl') target = make_target(file_path=file_path, unique_id=None) target.dump(obj, lock_at_dump=True) wrap_with_lock_mock.assert_called_once() def test_dump_without_lock(self): with patch('gokart.target.wrap_dump_with_lock') as wrap_with_lock_mock: obj = 1 file_path = os.path.join(self.temporary_directory, 'test.pkl') target = make_target(file_path=file_path, unique_id=None) target.dump(obj, lock_at_dump=False) wrap_with_lock_mock.assert_not_called() class S3TargetTest(unittest.TestCase): @mock_aws def test_save_on_s3(self): conn = boto3.resource('s3', region_name='us-east-1') 
conn.create_bucket(Bucket='test') obj = 1 file_path = os.path.join('s3://test/', 'test.pkl') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() self.assertEqual(loaded, obj) @mock_aws def test_last_modified_time(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') obj = 1 file_path = os.path.join('s3://test/', 'test.pkl') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) t = target.last_modification_time() self.assertIsInstance(t, datetime) @mock_aws def test_last_modified_time_without_file(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') file_path = os.path.join('s3://test/', 'test.pkl') target = make_target(file_path=file_path, unique_id=None) with self.assertRaises(FileNotFoundError): target.last_modification_time() @mock_aws def test_save_on_s3_feather(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) file_path = os.path.join('s3://test/', 'test.feather') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) @mock_aws def test_save_on_s3_parquet(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4])) file_path = os.path.join('s3://test/', 'test.parquet') target = make_target(file_path=file_path, unique_id=None) target.dump(obj) loaded = target.load() pd.testing.assert_frame_equal(loaded, obj) class ModelTargetTest(unittest.TestCase): def setUp(self): self.temporary_directory = _get_temporary_directory() def tearDown(self): shutil.rmtree(self.temporary_directory, ignore_errors=True) @staticmethod def _save_function(obj, path): make_target(file_path=path).dump(obj) @staticmethod def _load_function(path): return make_target(file_path=path).load() def 
test_model_target_on_local(self): obj = 1 file_path = os.path.join(self.temporary_directory, 'test.zip') target = make_model_target( file_path=file_path, temporary_directory=self.temporary_directory, save_function=self._save_function, load_function=self._load_function ) target.dump(obj) loaded = target.load() self.assertEqual(loaded, obj) @mock_aws def test_model_target_on_s3(self): conn = boto3.resource('s3', region_name='us-east-1') conn.create_bucket(Bucket='test') obj = 1 file_path = os.path.join('s3://test/', 'test.zip') target = make_model_target( file_path=file_path, temporary_directory=self.temporary_directory, save_function=self._save_function, load_function=self._load_function ) target.dump(obj) loaded = target.load() self.assertEqual(loaded, obj) if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_task_instance_parameter.py ================================================ import unittest from typing import Any import luigi import gokart from gokart import ListTaskInstanceParameter, TaskInstanceParameter, TaskOnKart class _DummySubTask(TaskOnKart[Any]): task_namespace = __name__ pass class _DummyCorrectSubClassTask(_DummySubTask): task_namespace = __name__ pass class _DummyInvalidSubClassTask(TaskOnKart[Any]): task_namespace = __name__ pass class _DummyTask(TaskOnKart[Any]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter() task: TaskInstanceParameter[_DummySubTask] = TaskInstanceParameter(default=_DummySubTask()) class _DummyListTask(TaskOnKart[Any]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter() task: ListTaskInstanceParameter[_DummySubTask] = ListTaskInstanceParameter(default=[_DummySubTask(), _DummySubTask()]) class TaskInstanceParameterTest(unittest.TestCase): def setUp(self): _DummyTask.clear_instance_cache() def test_serialize_and_parse(self): original = _DummyTask(param=2) s = gokart.TaskInstanceParameter().serialize(original) parsed = 
gokart.TaskInstanceParameter().parse(s) self.assertEqual(parsed.task_id, original.task_id) def test_serialize_and_parse_list_params(self): original = _DummyListTask(param=2) s = gokart.TaskInstanceParameter().serialize(original) parsed = gokart.TaskInstanceParameter().parse(s) self.assertEqual(parsed.task_id, original.task_id) def test_invalid_class(self): self.assertRaises(TypeError, lambda: gokart.TaskInstanceParameter(expected_type=1)) # type: ignore def test_params_with_correct_param_type(self): class _DummyPipelineA(TaskOnKart[Any]): task_namespace = __name__ subtask: gokart.TaskInstanceParameter[_DummySubTask] = gokart.TaskInstanceParameter(expected_type=_DummySubTask) task = _DummyPipelineA(subtask=_DummyCorrectSubClassTask()) self.assertEqual(task.requires()['subtask'], _DummyCorrectSubClassTask()) # type: ignore def test_params_with_invalid_param_type(self): class _DummyPipelineB(TaskOnKart[Any]): task_namespace = __name__ subtask: gokart.TaskInstanceParameter[_DummySubTask] = gokart.TaskInstanceParameter(expected_type=_DummySubTask) with self.assertRaises(TypeError): _DummyPipelineB(subtask=_DummyInvalidSubClassTask()) # type: ignore class ListTaskInstanceParameterTest(unittest.TestCase): def setUp(self): _DummyTask.clear_instance_cache() def test_invalid_class(self): self.assertRaises(TypeError, lambda: gokart.ListTaskInstanceParameter(expected_elements_type=1)) # type: ignore # not type instance def test_list_params_with_correct_param_types(self): class _DummyPipelineC(TaskOnKart[Any]): task_namespace = __name__ subtask: gokart.ListTaskInstanceParameter[_DummySubTask] = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()]) self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(),)) # type: ignore def test_list_params_with_invalid_param_types(self): class _DummyPipelineD(TaskOnKart[Any]): task_namespace = __name__ subtask: 
gokart.ListTaskInstanceParameter[_DummySubTask] = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask) with self.assertRaises(TypeError): _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()]) # type: ignore if __name__ == '__main__': unittest.main() ================================================ FILE: test/test_task_on_kart.py ================================================ from __future__ import annotations import os import pathlib import unittest from datetime import datetime from typing import Any, cast from unittest.mock import Mock, patch import luigi import pandas as pd from luigi.parameter import ParameterVisibility from luigi.util import inherits import gokart from gokart.file_processor import XmlFileProcessor from gokart.parameter import ListTaskInstanceParameter, TaskInstanceParameter from gokart.target import ModelTarget, SingleFileTarget, TargetOnKart from gokart.task import EmptyDumpError class _DummyTask(gokart.TaskOnKart[Any]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter(default=1) list_param: luigi.ListParameter[tuple[str, ...]] = luigi.ListParameter(default=('a', 'b')) bool_param: luigi.BoolParameter = luigi.BoolParameter() def output(self): return None class _DummyTaskA(gokart.TaskOnKart[Any]): task_namespace = __name__ def output(self): return None @inherits(_DummyTaskA) class _DummyTaskB(gokart.TaskOnKart[Any]): task_namespace = __name__ def output(self): return None def requires(self): return self.clone(_DummyTaskA) @inherits(_DummyTaskB) class _DummyTaskC(gokart.TaskOnKart[Any]): task_namespace = __name__ def output(self): return None def requires(self): return self.clone(_DummyTaskB) class _DummyTaskD(gokart.TaskOnKart[Any]): task_namespace = __name__ class _DummyTaskWithoutLock(gokart.TaskOnKart[Any]): task_namespace = __name__ def run(self): pass class _DummySubTaskWithPrivateParameter(gokart.TaskOnKart[Any]): task_namespace = __name__ class 
_DummyTaskWithPrivateParameter(gokart.TaskOnKart[Any]): task_namespace = __name__ int_param: luigi.IntParameter = luigi.IntParameter() private_int_param: luigi.IntParameter = luigi.IntParameter(visibility=ParameterVisibility.PRIVATE) task_param: TaskInstanceParameter[Any] = TaskInstanceParameter() list_task_param: ListTaskInstanceParameter[Any] = ListTaskInstanceParameter() class TaskTest(unittest.TestCase): def setUp(self): _DummyTask.clear_instance_cache() _DummyTaskA.clear_instance_cache() _DummyTaskB.clear_instance_cache() _DummyTaskC.clear_instance_cache() def test_complete_without_dependency(self): task = _DummyTask() self.assertTrue(task.complete(), msg='_DummyTask does not have any output files, so this always must be completed.') def test_complete_with_rerun_flag(self): task = _DummyTask(rerun=True) self.assertFalse(task.complete(), msg='"rerun" flag force tasks rerun once.') self.assertTrue(task.complete(), msg='"rerun" flag should be changed.') def test_complete_with_uncompleted_input(self): uncompleted_target = Mock(spec=TargetOnKart) uncompleted_target.exists.return_value = False # depends on an uncompleted target. task = _DummyTask() task.input = Mock(return_value=uncompleted_target) # type: ignore self.assertTrue(task.complete(), msg='task does not care input targets.') # make a task check its inputs. task.strict_check = True self.assertFalse(task.complete()) def test_complete_with_modified_input(self): input_target = Mock(spec=TargetOnKart) input_target.exists.return_value = True input_target.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0) output_target = Mock(spec=TargetOnKart) output_target.exists.return_value = True output_target.last_modification_time.return_value = datetime(2018, 1, 1, 9, 0, 0) # depends on an uncompleted target. 
task = _DummyTask() task.modification_time_check = False task.input = Mock(return_value=input_target) # type: ignore task.output = Mock(return_value=output_target) # type: ignore self.assertTrue(task.complete(), msg='task does not care modified time') # make a task check its inputs. task.modification_time_check = True self.assertFalse(task.complete()) def test_complete_when_modification_time_equals_output(self): """Test the case that modification time of input equals that of output. The case is occurred when input and output targets are same. """ input_target = Mock(spec=TargetOnKart) input_target.exists.return_value = True input_target.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0) output_target = Mock(spec=TargetOnKart) output_target.exists.return_value = True output_target.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0) task = _DummyTask() task.modification_time_check = True task.input = Mock(return_value=input_target) # type: ignore task.output = Mock(return_value=output_target) # type: ignore self.assertTrue(task.complete()) def test_complete_when_input_and_output_equal(self): target1 = Mock(spec=TargetOnKart) target1.exists.return_value = True target1.path.return_value = 'path1.pkl' target1.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0) target2 = Mock(spec=TargetOnKart) target2.exists.return_value = True target2.path.return_value = 'path2.pkl' target2.last_modification_time.return_value = datetime(2018, 1, 1, 9, 0, 0) target3 = Mock(spec=TargetOnKart) target3.exists.return_value = True target3.path.return_value = 'path3.pkl' target3.last_modification_time.return_value = datetime(2018, 1, 1, 9, 0, 0) task = _DummyTask() task.modification_time_check = True task.input = Mock(return_value=[target1, target2]) # type: ignore task.output = Mock(return_value=[target1, target2]) # type: ignore self.assertTrue(task.complete()) task.input = Mock(return_value=[target1, target2]) # type: ignore task.output = 
Mock(return_value=[target2, target3]) # type: ignore self.assertFalse(task.complete()) def test_default_target(self): task = _DummyTaskD() default_target = task.output() self.assertIsInstance(default_target, SingleFileTarget) self.assertEqual(f'_DummyTaskD_{task.task_unique_id}.pkl', pathlib.Path(default_target._target.path).name) # type: ignore def test_clone_with_special_params(self): class _DummyTaskRerun(gokart.TaskOnKart[Any]): a: luigi.BoolParameter = luigi.BoolParameter(default=False) task = _DummyTaskRerun(a=True, rerun=True) cloned = task.clone(_DummyTaskRerun) cloned_with_explicit_rerun = task.clone(_DummyTaskRerun, rerun=True) self.assertTrue(cloned.a) self.assertFalse(cloned.rerun) # do not clone rerun self.assertTrue(cloned_with_explicit_rerun.a) self.assertTrue(cloned_with_explicit_rerun.rerun) def test_default_large_dataframe_target(self): task = _DummyTaskD() default_large_dataframe_target = task.make_large_data_frame_target() self.assertIsInstance(default_large_dataframe_target, ModelTarget) target = cast(ModelTarget, default_large_dataframe_target) self.assertEqual(f'_DummyTaskD_{task.task_unique_id}.zip', pathlib.Path(target._zip_client.path).name) def test_make_target(self): task = _DummyTask() target = task.make_target('test.txt') self.assertIsInstance(target, SingleFileTarget) def test_make_target_without_id(self): path = _DummyTask().make_target('test.txt', use_unique_id=False).path() self.assertEqual(path, os.path.join(_DummyTask().workspace_directory, 'test.txt')) def test_make_target_with_processor(self): task = _DummyTask() processor = XmlFileProcessor() target = task.make_target('test.dummy', processor=processor) self.assertIsInstance(target, SingleFileTarget) target = cast(SingleFileTarget, target) self.assertEqual(target._processor, processor) def test_get_own_code(self): task = _DummyTask() task_scripts = 'def output(self):\nreturn None\n' self.assertEqual(task.get_own_code().replace(' ', ''), task_scripts.replace(' ', '')) def 
test_make_unique_id_with_own_code(self): class _MyDummyTaskA(gokart.TaskOnKart[str]): _visible_in_registry = False def run(self): self.dump('Hello, world!') task_unique_id = _MyDummyTaskA(serialized_task_definition_check=False).make_unique_id() task_with_code_unique_id = _MyDummyTaskA(serialized_task_definition_check=True).make_unique_id() self.assertNotEqual(task_unique_id, task_with_code_unique_id) class _MyDummyTaskA(gokart.TaskOnKart[str]): # type: ignore _visible_in_registry = False def run(self): modified_code = 'modify!!' self.dump(modified_code) task_modified_unique_id = _MyDummyTaskA(serialized_task_definition_check=False).make_unique_id() task_modified_with_code_unique_id = _MyDummyTaskA(serialized_task_definition_check=True).make_unique_id() self.assertEqual(task_modified_unique_id, task_unique_id) self.assertNotEqual(task_modified_with_code_unique_id, task_with_code_unique_id) def test_compare_targets_of_different_tasks(self): path1 = _DummyTask(param=1).make_target('test.txt').path() path2 = _DummyTask(param=2).make_target('test.txt').path() self.assertNotEqual(path1, path2, msg='different tasks must generate different targets.') def test_make_model_target(self): task = _DummyTask() target = task.make_model_target('test.zip', save_function=Mock(), load_function=Mock()) self.assertIsInstance(target, ModelTarget) def test_load_with_single_target(self): task = _DummyTask() target = Mock(spec=TargetOnKart) target.load.return_value = 1 task.input = Mock(return_value=target) # type: ignore data = task.load() target.load.assert_called_once() self.assertEqual(data, 1) def test_load_with_single_dict_target(self): task = _DummyTask() target = Mock(spec=TargetOnKart) target.load.return_value = 1 task.input = Mock(return_value={'target_key': target}) # type: ignore data = task.load() target.load.assert_called_once() self.assertEqual(data, {'target_key': 1}) def test_load_with_keyword(self): task = _DummyTask() target = Mock(spec=TargetOnKart) 
target.load.return_value = 1 task.input = Mock(return_value={'target_key': target}) # type: ignore data = task.load('target_key') target.load.assert_called_once() self.assertEqual(data, 1) def test_load_tuple(self): task = _DummyTask() target1 = Mock(spec=TargetOnKart) target1.load.return_value = 1 target2 = Mock(spec=TargetOnKart) target2.load.return_value = 2 task.input = Mock(return_value=(target1, target2)) # type: ignore data = task.load() target1.load.assert_called_once() target2.load.assert_called_once() self.assertEqual(data[0], 1) self.assertEqual(data[1], 2) def test_load_dictionary_at_once(self): task = _DummyTask() target1 = Mock(spec=TargetOnKart) target1.load.return_value = 1 target2 = Mock(spec=TargetOnKart) target2.load.return_value = 2 task.input = Mock(return_value={'target_key_1': target1, 'target_key_2': target2}) # type: ignore data = task.load() target1.load.assert_called_once() target2.load.assert_called_once() self.assertEqual(data['target_key_1'], 1) self.assertEqual(data['target_key_2'], 2) def test_load_with_task_on_kart(self): task = _DummyTask() task2 = Mock(spec=gokart.TaskOnKart) task2.make_unique_id.return_value = 'task2' task2_output = Mock(spec=TargetOnKart) task2.output.return_value = task2_output task2_output.load.return_value = 1 # task2 should be in requires' return values task.requires = lambda: {'task2': task2} # type: ignore actual = task.load(task2) self.assertEqual(actual, 1) def test_load_with_task_on_kart_should_fail_when_task_on_kart_is_not_in_requires(self): """ if load args is not in requires, it should raise an error. 
""" task = _DummyTask() task2 = Mock(spec=gokart.TaskOnKart) task2_output = Mock(spec=TargetOnKart) task2.output.return_value = task2_output task2_output.load.return_value = 1 with self.assertRaises(AssertionError): task.load(task2) def test_load_with_task_on_kart_list(self): task = _DummyTask() task2 = Mock(spec=gokart.TaskOnKart) task2.make_unique_id.return_value = 'task2' task2_output = Mock(spec=TargetOnKart) task2.output.return_value = task2_output task2_output.load.return_value = 1 task3 = Mock(spec=gokart.TaskOnKart) task3.make_unique_id.return_value = 'task3' task3_output = Mock(spec=TargetOnKart) task3.output.return_value = task3_output task3_output.load.return_value = 2 # task2 should be in requires' return values task.requires = lambda: {'tasks': [task2, task3]} # type: ignore load_args: list[gokart.TaskOnKart[int]] = [task2, task3] actual = task.load(load_args) self.assertEqual(actual, [1, 2]) def test_load_generator_with_single_target(self): task = _DummyTask() target = Mock(spec=TargetOnKart) target.load.return_value = [1, 2] task.input = Mock(return_value=target) # type: ignore data = [x for x in task.load_generator()] self.assertEqual(data, [[1, 2]]) def test_load_generator_with_keyword(self): task = _DummyTask() target = Mock(spec=TargetOnKart) target.load.return_value = [1, 2] task.input = Mock(return_value={'target_key': target}) # type: ignore data = [x for x in task.load_generator('target_key')] self.assertEqual(data, [[1, 2]]) def test_load_generator_with_list_task_on_kart(self): task = _DummyTask() task2 = Mock(spec=gokart.TaskOnKart) task2.make_unique_id.return_value = 'task2' task2_output = Mock(spec=TargetOnKart) task2.output.return_value = task2_output task2_output.load.return_value = 1 task3 = Mock(spec=gokart.TaskOnKart) task3.make_unique_id.return_value = 'task3' task3_output = Mock(spec=TargetOnKart) task3.output.return_value = task3_output task3_output.load.return_value = 2 # task2 should be in requires' return values task.requires = 
lambda: {'tasks': [task2, task3]} # type: ignore load_args: list[gokart.TaskOnKart[int]] = [task2, task3] actual = [x for x in task.load_generator(load_args)] self.assertEqual(actual, [1, 2]) def test_dump(self): task = _DummyTask() target = Mock(spec=TargetOnKart) task.output = Mock(return_value=target) # type: ignore task.dump(1) target.dump.assert_called_once() def test_fail_on_empty_dump(self): # do not fail task = _DummyTask(fail_on_empty_dump=False) target = Mock(spec=TargetOnKart) task.output = Mock(return_value=target) # type: ignore task.dump(pd.DataFrame()) target.dump.assert_called_once() # fail task = _DummyTask(fail_on_empty_dump=True) self.assertRaises(EmptyDumpError, lambda: task.dump(pd.DataFrame())) @patch('luigi.configuration.get_config') def test_add_configuration(self, mock_config: Mock) -> None: mock_config.return_value = {'_DummyTask': {'list_param': '["c", "d"]', 'param': '3', 'bool_param': 'True'}} kwargs: dict[str, Any] = dict() _DummyTask._add_configuration(kwargs, '_DummyTask') self.assertEqual(3, kwargs['param']) self.assertEqual(['c', 'd'], list(kwargs['list_param'])) self.assertEqual(True, kwargs['bool_param']) @patch('luigi.cmdline_parser.CmdlineParser.get_instance') def test_add_cofigureation_evaluation_order(self, mock_cmdline: Mock) -> None: """ in case TaskOnKart._add_configuration will break evaluation order @see https://luigi.readthedocs.io/en/stable/parameters.html#parameter-resolution-order """ class DummyTaskAddConfiguration(gokart.TaskOnKart[Any]): aa = luigi.IntParameter() luigi.configuration.get_config().set('DummyTaskAddConfiguration', 'aa', '3') mock_cmdline.return_value = luigi.cmdline_parser.CmdlineParser(['DummyTaskAddConfiguration']) self.assertEqual(DummyTaskAddConfiguration().aa, 3) mock_cmdline.return_value = luigi.cmdline_parser.CmdlineParser(['DummyTaskAddConfiguration', '--DummyTaskAddConfiguration-aa', '2']) self.assertEqual(DummyTaskAddConfiguration().aa, 2) def test_use_rerun_with_inherits(self): # All tasks 
are completed. task_c = _DummyTaskC() self.assertTrue(task_c.complete()) self.assertTrue(task_c.requires().complete()) # This is an instance of TaskB. self.assertTrue(task_c.requires().requires().complete()) # This is an instance of TaskA. luigi.configuration.get_config().set(f'{__name__}._DummyTaskB', 'rerun', 'True') task_c = _DummyTaskC() self.assertTrue(task_c.complete()) self.assertFalse(task_c.requires().complete()) # This is an instance of _DummyTaskB. self.assertTrue(task_c.requires().requires().complete()) # This is an instance of _DummyTaskA. # All tasks are not completed, because _DummyTaskC.rerun = True. task_c = _DummyTaskC(rerun=True) self.assertFalse(task_c.complete()) self.assertTrue(task_c.requires().complete()) # This is an instance of _DummyTaskB. self.assertTrue(task_c.requires().requires().complete()) # This is an instance of _DummyTaskA. def test_significant_flag(self) -> None: def _make_task(significant: bool, has_required_task: bool) -> gokart.TaskOnKart[Any]: class _MyDummyTaskA(gokart.TaskOnKart[Any]): task_namespace = f'{__name__}_{significant}_{has_required_task}' class _MyDummyTaskB(gokart.TaskOnKart[Any]): task_namespace = f'{__name__}_{significant}_{has_required_task}' def requires(self): if has_required_task: return _MyDummyTaskA(significant=significant) return return _MyDummyTaskB() x_task = _make_task(significant=True, has_required_task=True) y_task = _make_task(significant=False, has_required_task=True) z_task = _make_task(significant=False, has_required_task=False) self.assertNotEqual(x_task.make_unique_id(), y_task.make_unique_id()) self.assertEqual(y_task.make_unique_id(), z_task.make_unique_id()) def test_default_requires(self): class _WithoutTaskInstanceParameter(gokart.TaskOnKart[Any]): task_namespace = __name__ class _WithTaskInstanceParameter(gokart.TaskOnKart[Any]): task_namespace = __name__ a_task: gokart.TaskInstanceParameter[Any] = gokart.TaskInstanceParameter() without_task = _WithoutTaskInstanceParameter() 
self.assertListEqual(without_task.requires(), []) # type: ignore with_task = _WithTaskInstanceParameter(a_task=without_task) self.assertEqual(with_task.requires()['a_task'], without_task) # type: ignore def test_repr(self): task = _DummyTaskWithPrivateParameter( int_param=1, private_int_param=1, task_param=_DummySubTaskWithPrivateParameter(), list_task_param=[_DummySubTaskWithPrivateParameter(), _DummySubTaskWithPrivateParameter()], ) task_id = task.make_unique_id() sub_task_id = _DummySubTaskWithPrivateParameter().make_unique_id() expected = ( f'{__name__}._DummyTaskWithPrivateParameter[{task_id}](int_param=1, private_int_param=1, task_param={__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), ' f'list_task_param=[{__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), {__name__}._DummySubTaskWithPrivateParameter({sub_task_id})])' ) # noqa:E501 self.assertEqual(expected, repr(task)) def test_str(self): task = _DummyTaskWithPrivateParameter( int_param=1, private_int_param=1, task_param=_DummySubTaskWithPrivateParameter(), list_task_param=[_DummySubTaskWithPrivateParameter(), _DummySubTaskWithPrivateParameter()], ) task_id = task.make_unique_id() sub_task_id = _DummySubTaskWithPrivateParameter().make_unique_id() expected = ( f'{__name__}._DummyTaskWithPrivateParameter[{task_id}](int_param=1, task_param={__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), ' f'list_task_param=[{__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), {__name__}._DummySubTaskWithPrivateParameter({sub_task_id})])' ) self.assertEqual(expected, str(task)) def test_is_task_on_kart(self): self.assertEqual(True, gokart.TaskOnKart.is_task_on_kart(gokart.TaskOnKart())) self.assertEqual(False, gokart.TaskOnKart.is_task_on_kart(1)) self.assertEqual(False, gokart.TaskOnKart.is_task_on_kart(list())) self.assertEqual(True, gokart.TaskOnKart.is_task_on_kart((gokart.TaskOnKart(), gokart.TaskOnKart()))) def test_serialize_and_deserialize_default_values(self): task: 
gokart.TaskOnKart[Any] = gokart.TaskOnKart()
        deserialized: gokart.TaskOnKart[Any] = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params())
        self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())

    def test_to_str_params_changes_on_values_and_flags(self):
        class _DummyTaskWithParams(gokart.TaskOnKart[Any]):
            task_namespace = __name__
            param: luigi.Parameter = luigi.Parameter()

        t1 = _DummyTaskWithParams(param='a')
        self.assertEqual(t1.to_str_params(), t1.to_str_params())  # cache
        self.assertEqual(t1.to_str_params(), _DummyTaskWithParams(param='a').to_str_params())  # same value
        self.assertNotEqual(t1.to_str_params(), _DummyTaskWithParams(param='b').to_str_params())  # different value
        self.assertNotEqual(t1.to_str_params(), t1.to_str_params(only_significant=True))

    def test_should_lock_run_when_set(self):
        class _DummyTaskWithLock(gokart.TaskOnKart[str]):
            def run(self):
                self.dump('hello')

        task = _DummyTaskWithLock(redis_host='host', redis_port=123, redis_timeout=180, should_lock_run=True)
        # run() being wrapped (functools.wraps keeps __name__) shows the lock decorator was applied
        self.assertEqual(task.run.__wrapped__.__name__, 'run')  # type: ignore

    def test_should_fail_lock_run_when_host_unset(self):
        with self.assertRaises(AssertionError):
            gokart.TaskOnKart(redis_port=123, redis_timeout=180, should_lock_run=True)

    def test_should_fail_lock_run_when_port_unset(self):
        with self.assertRaises(AssertionError):
            gokart.TaskOnKart(redis_host='host', redis_timeout=180, should_lock_run=True)


class _DummyTaskWithNonCompleted(gokart.TaskOnKart[Any]):
    # Helper task whose complete() always reports False.
    def dump(self, _obj: Any, _target: Any = None, _custom_labels: Any = None) -> None:
        # override dump() to do nothing.
        pass

    def run(self):
        self.dump('hello')

    def complete(self):
        return False


class _DummyTaskWithCompleted(gokart.TaskOnKart[Any]):
    # Helper task whose complete() always reports True.
    def dump(self, obj: Any, _target: Any = None, custom_labels: Any = None) -> None:
        # override dump() to do nothing.
pass

    def run(self):
        self.dump('hello')

    def complete(self):
        return True


class TestCompleteCheckAtRun(unittest.TestCase):
    """Verify how the ``complete_check_at_run`` flag gates run(): when it is True and the task
    already reports complete, dump() must not be invoked."""

    def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
        task = _DummyTaskWithNonCompleted(complete_check_at_run=False)
        task.dump = Mock()  # type: ignore
        task.run()
        # since run() is called, dump() should be called.
        task.dump.assert_called_once()

    def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
        task = _DummyTaskWithCompleted(complete_check_at_run=False)
        task.dump = Mock()  # type: ignore
        task.run()
        # even task is completed, since run() is called, dump() should be called.
        task.dump.assert_called_once()

    def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
        task = _DummyTaskWithNonCompleted(complete_check_at_run=True)
        task.dump = Mock()  # type: ignore
        task.run()
        # since task is not completed, when run() is called, dump() should be called.
        task.dump.assert_called_once()

    def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
        task = _DummyTaskWithCompleted(complete_check_at_run=True)
        task.dump = Mock()  # type: ignore
        task.run()
        # since task is completed, even when run() is called, dump() should not be called.
task.dump.assert_not_called()


if __name__ == '__main__':
    unittest.main()

================================================ FILE: test/test_utils.py ================================================
import unittest
from typing import TYPE_CHECKING

import pandas as pd
import pytest

from gokart.task import TaskOnKart
from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items

if TYPE_CHECKING:
    import polars as pl

# polars is an optional dependency; polars-specific tests below are skipped when it is absent.
try:
    import polars as pl

    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False


class TestFlatten(unittest.TestCase):
    # flatten() collapses nested containers into a flat list; scalars become
    # single-element lists and None becomes the empty list.
    def test_flatten_dict(self):
        self.assertEqual(flatten({'a': 'foo', 'b': 'bar'}), ['foo', 'bar'])

    def test_flatten_list(self):
        self.assertEqual(flatten(['foo', ['bar', 'troll']]), ['foo', 'bar', 'troll'])

    def test_flatten_str(self):
        self.assertEqual(flatten('foo'), ['foo'])

    def test_flatten_int(self):
        self.assertEqual(flatten(42), [42])

    def test_flatten_none(self):
        self.assertEqual(flatten(None), [])


class TestMapFlatten(unittest.TestCase):
    # map_flattenable_items() applies a function to every leaf while keeping the
    # container structure (dict/tuple/list nesting) intact.
    def test_map_flattenable_items(self):
        self.assertEqual(map_flattenable_items(lambda x: str(x), {'a': 1, 'b': 2}), {'a': '1', 'b': '2'})
        self.assertEqual(
            map_flattenable_items(lambda x: str(x), (1, 2, 3, (4, 5, (6, 7, {'a': (8, 9, 0)})))),
            ('1', '2', '3', ('4', '5', ('6', '7', {'a': ('8', '9', '0')}))),
        )
        self.assertEqual(
            map_flattenable_items(
                lambda x: str(x),
                {'a': [1, 2, 3, '4'], 'b': {'c': True, 'd': {'e': 5}}},
            ),
            {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}},
        )


class TestGetDataFrameTypeFromTask(unittest.TestCase):
    """Tests for get_dataframe_type_from_task function."""

    def test_pandas_dataframe_from_instance(self):
        """Test detecting pandas DataFrame from task instance."""

        class _PandasTaskInstance(TaskOnKart[pd.DataFrame]):
            pass

        task = _PandasTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_pandas_dataframe_from_class(self):
        """Test detecting pandas DataFrame from task class."""

        class
_PandasTaskClass(TaskOnKart[pd.DataFrame]): pass self.assertEqual(get_dataframe_type_from_task(_PandasTaskClass), 'pandas') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_polars_dataframe_from_instance(self): """Test detecting polars DataFrame from task instance.""" class _PolarsTaskInstance(TaskOnKart[pl.DataFrame]): pass task = _PolarsTaskInstance() self.assertEqual(get_dataframe_type_from_task(task), 'polars') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_polars_dataframe_from_class(self): """Test detecting polars DataFrame from task class.""" class _PolarsTaskClass(TaskOnKart[pl.DataFrame]): pass self.assertEqual(get_dataframe_type_from_task(_PolarsTaskClass), 'polars') def test_no_type_parameter_defaults_to_pandas(self): """Test that tasks without type parameter default to pandas.""" # Create a class without __orig_bases__ by not using type parameters class PlainTask: pass task = PlainTask() self.assertEqual(get_dataframe_type_from_task(task), 'pandas') def test_non_taskonkart_class_defaults_to_pandas(self): """Test that non-TaskOnKart classes default to pandas.""" class RegularClass: pass task = RegularClass() self.assertEqual(get_dataframe_type_from_task(task), 'pandas') def test_taskonkart_with_non_dataframe_type(self): """Test TaskOnKart with non-DataFrame type parameter defaults to pandas.""" class _StringTask(TaskOnKart[str]): pass task = _StringTask() # Should default to pandas since str module is not 'pandas' or 'polars' self.assertEqual(get_dataframe_type_from_task(task), 'pandas') def test_nested_inheritance_pandas(self): """Test that nested inheritance without direct type parameter defaults to pandas.""" class _BasePandasTask(TaskOnKart[pd.DataFrame]): pass class _DerivedPandasTask(_BasePandasTask): pass task = _DerivedPandasTask() # _DerivedPandasTask doesn't have its own __orig_bases__ with type parameter, # so it defaults to 'pandas' self.assertEqual(get_dataframe_type_from_task(task), 
'pandas') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_nested_inheritance_polars(self): """Test detecting polars DataFrame type through nested inheritance.""" class _BasePolarsTask(TaskOnKart[pl.DataFrame]): pass class _DerivedPolarsTask(_BasePolarsTask): pass task = _DerivedPolarsTask() # Function should detect 'polars' through the inheritance chain self.assertEqual(get_dataframe_type_from_task(task), 'polars') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_polars_lazyframe_from_instance(self): class _LazyTaskInstance(TaskOnKart[pl.LazyFrame]): pass task = _LazyTaskInstance() self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_polars_lazyframe_from_class(self): class _LazyTaskClass(TaskOnKart[pl.LazyFrame]): pass self.assertEqual(get_dataframe_type_from_task(_LazyTaskClass), 'polars-lazy') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_nested_inheritance_polars_lazyframe(self): class _BaseLazyTask(TaskOnKart[pl.LazyFrame]): pass class _DerivedLazyTask(_BaseLazyTask): pass task = _DerivedLazyTask() self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy') @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed') def test_nested_inheritance_polars_with_mixin(self): """Derived class with multiple bases should still detect polars through MRO.""" class _Mixin: pass class _BasePolarsTaskWithMixin(TaskOnKart[pl.DataFrame]): pass # Multiple inheritance gives _DerivedTask its own __orig_bases__, # which shadows the parent's and doesn't contain TaskOnKart[...]. 
        class _DerivedTaskWithMixin(_BasePolarsTaskWithMixin, _Mixin):
            pass

        task = _DerivedTaskWithMixin()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')

================================================ FILE: test/test_worker.py ================================================
import uuid
from unittest.mock import Mock

import luigi
import luigi.worker
import pytest
from luigi import scheduler

import gokart
from gokart.worker import Worker, gokart_worker


class _DummyTask(gokart.TaskOnKart[str]):
    # Minimal task for worker tests; random_id makes every instance unique
    # so repeated test runs never see a cached completion.
    task_namespace = __name__
    random_id: luigi.StrParameter = luigi.StrParameter()

    def _run(self): ...

    def run(self):
        self._run()
        self.dump('test')


class TestWorkerRun:
    def test_run(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Check run is called when the task is not completed"""
        sch = scheduler.Scheduler()
        worker = Worker(scheduler=sch)
        task = _DummyTask(random_id=uuid.uuid4().hex)

        # _run is the observable hook: run() delegates to it before dumping.
        mock_run = Mock()
        monkeypatch.setattr(task, '_run', mock_run)

        with worker:
            assert worker.add(task)
            assert worker.run()

        mock_run.assert_called_once()


class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]):
    task_namespace = __name__

    def _run(self): ...
    def run(self):
        self._run()
        self.dump(None)

    def complete(self) -> bool:
        return False


class TestWorkerSkipIfCompletedPreRun:
    # The worker may re-check complete() just before executing a task; this matrix
    # covers all combinations of the worker flag and the task's completion state.
    @pytest.mark.parametrize(
        'task_completion_check_at_run,is_completed,expect_skipped',
        [
            pytest.param(True, True, True, id='skipped when completed and task_completion_check_at_run is True'),
            pytest.param(True, False, False, id='not skipped when not completed and task_completion_check_at_run is True'),
            pytest.param(False, True, False, id='not skipped when completed and task_completion_check_at_run is False'),
            pytest.param(False, False, False, id='not skipped when not completed and task_completion_check_at_run is False'),
        ],
    )
    def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, task_completion_check_at_run: bool, is_completed: bool, expect_skipped: bool) -> None:
        sch = scheduler.Scheduler()
        worker = Worker(scheduler=sch, config=gokart_worker(task_completion_check_at_run=task_completion_check_at_run))
        mock_complete = Mock(return_value=is_completed)
        # NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic.
        task = _DummyTaskToCheckSkip(complete_check_at_run=False)

        mock_run = Mock()
        monkeypatch.setattr(task, '_run', mock_run)

        with worker:
            assert worker.add(task)
            # NOTE: mock `complete` after `add` because `add` calls `complete`
            # to check if the task is already completed.
            monkeypatch.setattr(task, 'complete', mock_complete)
            assert worker.run()

        if expect_skipped:
            mock_run.assert_not_called()
        else:
            mock_run.assert_called_once()


class TestWorkerCheckCompleteValue:
    def test_does_not_raise_for_boolean_values(self) -> None:
        worker = Worker(scheduler=scheduler.Scheduler())
        worker._check_complete_value(True)
        worker._check_complete_value(False)

    def test_raises_async_completion_exception_for_traceback_wrapper(self) -> None:
        # NOTE: When Task.complete() raises in an async check, the exception is wrapped
        # in TracebackWrapper. This branch must raise AsyncCompletionException.
worker = Worker(scheduler=scheduler.Scheduler()) wrapped = luigi.worker.TracebackWrapper(trace='dummy traceback') with pytest.raises(luigi.worker.AsyncCompletionException): worker._check_complete_value(wrapped) def test_raises_exception_for_non_boolean_value(self) -> None: # NOTE: Pass a non-bool value to verify the runtime guard against a misimplemented # Task.complete() returning a non-boolean. The type ignore is intentional. worker = Worker(scheduler=scheduler.Scheduler()) with pytest.raises(Exception, match='Return value of Task.complete'): worker._check_complete_value('not a bool') # type: ignore[arg-type] ================================================ FILE: test/test_zoned_date_second_parameter.py ================================================ import datetime import unittest from luigi.cmdline_parser import CmdlineParser from gokart import TaskOnKart, ZonedDateSecondParameter class ZonedDateSecondParameterTaskWithoutDefault(TaskOnKart[datetime.datetime]): task_namespace = __name__ dt: ZonedDateSecondParameter = ZonedDateSecondParameter() def run(self): self.dump(self.dt) class ZonedDateSecondParameterTaskWithDefault(TaskOnKart[datetime.datetime]): task_namespace = __name__ dt: ZonedDateSecondParameter = ZonedDateSecondParameter( default=datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9))) ) def run(self): self.dump(self.dt) class ZonedDateSecondParameterTest(unittest.TestCase): def setUp(self): self.default_datetime = datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9))) self.default_datetime_str = '2025-02-21T12:00:00+09:00' def test_default(self): with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault']) as cp: assert cp.get_task_obj().dt == self.default_datetime def test_parse_param_with_tz_suffix(self): with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00+09:00']) as 
cp:
            assert cp.get_task_obj().dt == datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9)))

    def test_parse_param_with_Z_suffix(self):
        # 'Z' suffix is parsed as UTC (zero offset).
        with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00Z']) as cp:
            assert cp.get_task_obj().dt == datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=0)))

    def test_parse_param_without_timezone_input(self):
        # A string without an offset yields a naive datetime (tzinfo=None).
        with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithoutDefault', '--dt', '2025-02-21T12:00:00']) as cp:
            assert cp.get_task_obj().dt == datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=None)

    def test_parse_method(self):
        actual = ZonedDateSecondParameter().parse(self.default_datetime_str)
        expected = self.default_datetime
        self.assertEqual(actual, expected)

    def test_serialize_task(self):
        # The parameter serializes back to its ISO-8601 form inside str(task).
        task = ZonedDateSecondParameterTaskWithoutDefault(dt=self.default_datetime)
        actual = str(task)
        expected = f'(dt={self.default_datetime_str})'
        self.assertTrue(actual.endswith(expected))


if __name__ == '__main__':
    unittest.main()

================================================ FILE: test/testing/__init__.py ================================================ ================================================ FILE: test/testing/test_pandas_assert.py ================================================
import unittest

import pandas as pd

import gokart


class TestPandasAssert(unittest.TestCase):
    # assert_frame_contents_equal compares frames ignoring row/column order.
    def test_assert_frame_contents_equal(self):
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])
        gokart.testing.assert_frame_contents_equal(resulted, expected)

    def test_assert_frame_contents_equal_with_small_error(self):
        expected = pd.DataFrame(data=dict(f1=[1.0001, 2.0001, 3.0001], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        resulted =
pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2.0002, 1.0002, 3.0002], f3=[222, 111, 333]), index=[1, 0, 2]) gokart.testing.assert_frame_contents_equal(resulted, expected, atol=1e-1) def test_assert_frame_contents_equal_with_duplicated_columns(self): expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2]) expected.columns = ['f1', 'f1', 'f2'] resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2]) resulted.columns = ['f2', 'f1', 'f1'] with self.assertRaises(AssertionError): gokart.testing.assert_frame_contents_equal(resulted, expected) def test_assert_frame_contents_equal_with_duplicated_indexes(self): expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2]) expected.index = [0, 1, 1] resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2]) expected.index = [1, 0, 1] with self.assertRaises(AssertionError): gokart.testing.assert_frame_contents_equal(resulted, expected) ================================================ FILE: test/tree/__init__.py ================================================ ================================================ FILE: test/tree/test_task_info.py ================================================ from __future__ import annotations import unittest from typing import Any from unittest.mock import patch import luigi import luigi.mock from luigi.mock import MockFileSystem, MockTarget import gokart from gokart.tree.task_info import dump_task_info_table, dump_task_info_tree, make_task_info_as_tree_str, make_task_info_tree class _SubTask(gokart.TaskOnKart[str]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter() def output(self): return self.make_target('sub_task.txt') def run(self): self.dump(f'task uid = {self.make_unique_id()}') class _Task(gokart.TaskOnKart[str]): task_namespace = __name__ param: luigi.IntParameter = luigi.IntParameter(default=10) sub: 
gokart.TaskInstanceParameter[_SubTask] = gokart.TaskInstanceParameter(default=_SubTask(param=20)) def requires(self): return self.sub def output(self): return self.make_target('task.txt') def run(self): self.dump(f'task uid = {self.make_unique_id()}') class _DoubleLoadSubTask(gokart.TaskOnKart[str]): task_namespace = __name__ sub1: gokart.TaskInstanceParameter[gokart.TaskOnKart[Any]] = gokart.TaskInstanceParameter() sub2: gokart.TaskInstanceParameter[gokart.TaskOnKart[Any]] = gokart.TaskInstanceParameter() def output(self): return self.make_target('sub_task.txt') def run(self): self.dump(f'task uid = {self.make_unique_id()}') class TestInfo(unittest.TestCase): def setUp(self) -> None: MockFileSystem().clear() luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False def tearDown(self) -> None: luigi.setup_logging.DaemonLogging._configured = False luigi.setup_logging.InterfaceLogging._configured = False @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_pending(self): task = _Task(param=1, sub=_SubTask(param=2)) # check before running tree = make_task_info_as_tree_str(task) expected = r""" └─-\(PENDING\) _Task\[[a-z0-9]*\] └─-\(PENDING\) _SubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_complete(self): task = _Task(param=1, sub=_SubTask(param=2)) # check after sub task runs gokart.build(task, reset_register=False) tree = make_task_info_as_tree_str(task) expected = r""" └─-\(COMPLETE\) _Task\[[a-z0-9]*\] └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_abbreviation(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check after sub task runs 
gokart.build(task, reset_register=False) tree = make_task_info_as_tree_str(task) expected = r""" └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\] \|--\(COMPLETE\) _Task\[[a-z0-9]*\] \| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\] └─-\(COMPLETE\) _Task\[[a-z0-9]*\] └─- \.\.\.$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_not_compress(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check after sub task runs gokart.build(task, reset_register=False) tree = make_task_info_as_tree_str(task, abbr=False) expected = r""" └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\] \|--\(COMPLETE\) _Task\[[a-z0-9]*\] \| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\] └─-\(COMPLETE\) _Task\[[a-z0-9]*\] └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_not_compress_ignore_task(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check after sub task runs gokart.build(task, reset_register=False) tree = make_task_info_as_tree_str(task, abbr=False, ignore_task_names=['_Task']) expected = r""" └─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$""" self.assertRegex(tree, expected) @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs)) def test_make_tree_info_with_cache(self): task = _DoubleLoadSubTask( sub1=_Task(param=1, sub=_SubTask(param=2)), sub2=_Task(param=1, sub=_SubTask(param=2)), ) # check child task_info is the same object tree = make_task_info_tree(task) self.assertTrue(tree.children_task_infos[0] is tree.children_task_infos[1]) class _TaskInfoExampleTaskA(gokart.TaskOnKart[Any]): task_namespace = __name__ class _TaskInfoExampleTaskB(gokart.TaskOnKart[Any]): task_namespace = __name__ class 
_TaskInfoExampleTaskC(gokart.TaskOnKart[str]): task_namespace = __name__ def requires(self): return dict(taskA=_TaskInfoExampleTaskA(), taskB=_TaskInfoExampleTaskB()) def run(self): self.dump('DONE') class TestTaskInfoTable(unittest.TestCase): def test_dump_task_info_table(self): with patch('gokart.target.SingleFileTarget.dump') as mock_obj: self.dumped_data: Any = None def _side_effect(obj, lock_at_dump): self.dumped_data = obj mock_obj.side_effect = _side_effect dump_task_info_table(task=_TaskInfoExampleTaskC(), task_info_dump_path='path.csv', ignore_task_names=['_TaskInfoExampleTaskB']) self.assertEqual(set(self.dumped_data['name']), {'_TaskInfoExampleTaskA', '_TaskInfoExampleTaskC'}) self.assertEqual( set(self.dumped_data.columns), {'name', 'unique_id', 'output_paths', 'params', 'processing_time', 'is_complete', 'task_log', 'requires'} ) class TestTaskInfoTree(unittest.TestCase): def test_dump_task_info_tree(self): with patch('gokart.target.SingleFileTarget.dump') as mock_obj: self.dumped_data: Any = None def _side_effect(obj, lock_at_dump): self.dumped_data = obj mock_obj.side_effect = _side_effect dump_task_info_tree(task=_TaskInfoExampleTaskC(), task_info_dump_path='path.pkl', ignore_task_names=['_TaskInfoExampleTaskB']) self.assertEqual(self.dumped_data.name, '_TaskInfoExampleTaskC') self.assertEqual(self.dumped_data.children_task_infos[0].name, '_TaskInfoExampleTaskA') self.assertEqual(self.dumped_data.requires.keys(), {'taskA', 'taskB'}) self.assertEqual(self.dumped_data.requires['taskA'].name, '_TaskInfoExampleTaskA') self.assertEqual(self.dumped_data.requires['taskB'].name, '_TaskInfoExampleTaskB') def test_dump_task_info_tree_with_invalid_path_extention(self): with patch('gokart.target.SingleFileTarget.dump') as mock_obj: self.dumped_data = None def _side_effect(obj, lock_at_dump): self.dumped_data = obj mock_obj.side_effect = _side_effect with self.assertRaises(AssertionError): dump_task_info_tree(task=_TaskInfoExampleTaskC(), 
task_info_dump_path='path.csv', ignore_task_names=['_TaskInfoExampleTaskB']) ================================================ FILE: test/tree/test_task_info_formatter.py ================================================ import unittest from typing import Any import gokart from gokart.tree.task_info_formatter import RequiredTask, _make_requires_info class _RequiredTaskExampleTaskA(gokart.TaskOnKart[Any]): task_namespace = __name__ class TestMakeRequiresInfo(unittest.TestCase): def test_make_requires_info_with_task_on_kart(self): requires = _RequiredTaskExampleTaskA() resulted = _make_requires_info(requires=requires) expected = RequiredTask(name=requires.__class__.__name__, unique_id=requires.make_unique_id()) self.assertEqual(resulted, expected) def test_make_requires_info_with_list(self): requires = [_RequiredTaskExampleTaskA()] resulted = _make_requires_info(requires=requires) expected = [RequiredTask(name=require.__class__.__name__, unique_id=require.make_unique_id()) for require in requires] self.assertEqual(resulted, expected) def test_make_requires_info_with_generator(self): def _requires_gen(): return (_RequiredTaskExampleTaskA() for _ in range(2)) resulted = _make_requires_info(requires=_requires_gen()) expected = [RequiredTask(name=require.__class__.__name__, unique_id=require.make_unique_id()) for require in _requires_gen()] self.assertEqual(resulted, expected) def test_make_requires_info_with_dict(self): requires = dict(taskA=_RequiredTaskExampleTaskA()) resulted = _make_requires_info(requires=requires) expected = {key: RequiredTask(name=require.__class__.__name__, unique_id=require.make_unique_id()) for key, require in requires.items()} self.assertEqual(resulted, expected) def test_make_requires_info_with_invalid(self): requires = [1, 2] with self.assertRaises(TypeError): _make_requires_info(requires=requires) ================================================ FILE: test/util.py ================================================ import os import uuid # TODO: 
use pytest.fixture to share this functionality with other tests def _get_temporary_directory(): _uuid = str(uuid.uuid4()) return os.path.abspath(os.path.join(os.path.dirname(__name__), f'temporary-{_uuid}')) ================================================ FILE: tox.ini ================================================ [tox] envlist = py{310,311,312,313,314},ruff,mypy skipsdist = True [testenv] runner = uv-venv-lock-runner dependency_groups = test commands = {envpython} -m pytest --cov=gokart --cov-report=xml -vv {posargs:} [testenv:ruff] dependency_groups = lint commands = ruff check {posargs:} ruff format --check {posargs:} [testenv:mypy] dependency_groups = lint commands = mypy gokart test {posargs:}