Repository: m3dev/gokart
Branch: master
Commit: 0d0609000123
Files: 125
Total size: 466.8 KB
Directory structure:
gitextract_mhb_xs2d/
├── .github/
│ ├── CODEOWNERS
│ └── workflows/
│ ├── format.yml
│ ├── publish.yml
│ └── test.yml
├── .gitignore
├── .readthedocs.yaml
├── LICENSE
├── README.md
├── docs/
│ ├── Makefile
│ ├── conf.py
│ ├── efficient_run_on_multi_workers.rst
│ ├── for_pandas.rst
│ ├── gokart.rst
│ ├── index.rst
│ ├── intro_to_gokart.rst
│ ├── logging.rst
│ ├── make.bat
│ ├── mypy_plugin.rst
│ ├── polars.rst
│ ├── requirements.txt
│ ├── setting_task_parameters.rst
│ ├── slack_notification.rst
│ ├── task_information.rst
│ ├── task_on_kart.rst
│ ├── task_parameters.rst
│ ├── task_settings.rst
│ ├── tutorial.rst
│ └── using_task_task_conflict_prevention_lock.rst
├── examples/
│ ├── gokart_notebook_example.ipynb
│ ├── logging.ini
│ └── param.ini
├── gokart/
│ ├── __init__.py
│ ├── build.py
│ ├── config_params.py
│ ├── conflict_prevention_lock/
│ │ ├── task_lock.py
│ │ └── task_lock_wrappers.py
│ ├── errors/
│ │ └── __init__.py
│ ├── file_processor/
│ │ ├── __init__.py
│ │ ├── base.py
│ │ ├── pandas.py
│ │ └── polars.py
│ ├── file_processor.py
│ ├── gcs_config.py
│ ├── gcs_obj_metadata_client.py
│ ├── gcs_zip_client.py
│ ├── in_memory/
│ │ ├── __init__.py
│ │ ├── data.py
│ │ ├── repository.py
│ │ └── target.py
│ ├── info.py
│ ├── mypy.py
│ ├── object_storage.py
│ ├── pandas_type_config.py
│ ├── parameter.py
│ ├── py.typed
│ ├── required_task_output.py
│ ├── run.py
│ ├── s3_config.py
│ ├── s3_zip_client.py
│ ├── slack/
│ │ ├── __init__.py
│ │ ├── event_aggregator.py
│ │ ├── slack_api.py
│ │ └── slack_config.py
│ ├── target.py
│ ├── task.py
│ ├── task_complete_check.py
│ ├── testing/
│ │ ├── __init__.py
│ │ ├── check_if_run_with_empty_data_frame.py
│ │ └── pandas_assert.py
│ ├── tree/
│ │ ├── task_info.py
│ │ └── task_info_formatter.py
│ ├── utils.py
│ ├── worker.py
│ ├── workspace_management.py
│ ├── zip_client.py
│ └── zip_client_util.py
├── luigi.cfg
├── pyproject.toml
├── test/
│ ├── __init__.py
│ ├── config/
│ │ ├── __init__.py
│ │ ├── pyproject.toml
│ │ ├── pyproject_disallow_missing_parameters.toml
│ │ └── test_config.ini
│ ├── conflict_prevention_lock/
│ │ ├── __init__.py
│ │ ├── test_task_lock.py
│ │ └── test_task_lock_wrappers.py
│ ├── file_processor/
│ │ ├── __init__.py
│ │ ├── test_base.py
│ │ ├── test_factory.py
│ │ ├── test_pandas.py
│ │ └── test_polars.py
│ ├── in_memory/
│ │ ├── test_in_memory_target.py
│ │ └── test_repository.py
│ ├── slack/
│ │ ├── __init__.py
│ │ └── test_slack_api.py
│ ├── test_build.py
│ ├── test_cache_unique_id.py
│ ├── test_config_params.py
│ ├── test_explicit_bool_parameter.py
│ ├── test_gcs_config.py
│ ├── test_gcs_obj_metadata_client.py
│ ├── test_info.py
│ ├── test_large_data_fram_processor.py
│ ├── test_list_task_instance_parameter.py
│ ├── test_mypy.py
│ ├── test_pandas_type_check_framework.py
│ ├── test_pandas_type_config.py
│ ├── test_restore_task_by_id.py
│ ├── test_run.py
│ ├── test_s3_config.py
│ ├── test_s3_zip_client.py
│ ├── test_serializable_parameter.py
│ ├── test_target.py
│ ├── test_task_instance_parameter.py
│ ├── test_task_on_kart.py
│ ├── test_utils.py
│ ├── test_worker.py
│ ├── test_zoned_date_second_parameter.py
│ ├── testing/
│ │ ├── __init__.py
│ │ └── test_pandas_assert.py
│ ├── tree/
│ │ ├── __init__.py
│ │ ├── test_task_info.py
│ │ └── test_task_info_formatter.py
│ └── util.py
└── tox.ini
================================================
FILE CONTENTS
================================================
================================================
FILE: .github/CODEOWNERS
================================================
* @Hi-king @yokomotod @hirosassa @mski-iksm @kitagry @ujiuji1259 @mamo3gr @hiro-o918
================================================
FILE: .github/workflows/format.yml
================================================
name: Lint
on:
push:
branches: [ master ]
pull_request:
jobs:
formatting-check:
name: Lint
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up the latest version of uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install dependencies
run: |
uv tool install --python-preference only-managed --python 3.13 tox --with tox-uv
- name: Run ruff and mypy
run: |
uvx --with tox-uv tox run -e ruff,mypy
================================================
FILE: .github/workflows/publish.yml
================================================
name: Publish
on:
push:
tags: '*'
jobs:
deploy:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v6
- name: Set up the latest version of uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Build and publish
env:
UV_PUBLISH_TOKEN: ${{ secrets.PYPI_API_TOKEN }}
run: |
uv build
uv publish
================================================
FILE: .github/workflows/test.yml
================================================
name: Test
on:
push:
branches: [ master ]
pull_request:
jobs:
tests:
runs-on: ${{ matrix.platform }}
strategy:
max-parallel: 7
matrix:
platform: ["ubuntu-latest"]
tox-env: ["py310", "py311", "py312", "py313", "py314"]
include:
- platform: macos-15
tox-env: "py313"
- platform: macos-latest
tox-env: "py313"
steps:
- uses: actions/checkout@v6
- name: Set up the latest version of uv
uses: astral-sh/setup-uv@v7
with:
enable-cache: true
- name: Install dependencies
run: |
uv tool install --python-preference only-managed --python 3.13 tox --with tox-uv
- name: Test with tox
run: uvx --with tox-uv tox run -e ${{ matrix.tox-env }}
================================================
FILE: .gitignore
================================================
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# pyenv
.python-version
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
# pycharm
.idea
# gokart
resources
examples/resources
# poetry
dist
# temporary data
temporary.zip
================================================
FILE: .readthedocs.yaml
================================================
# Read the Docs configuration file for Sphinx projects
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
# Required
version: 2
# Set the OS, Python version and other tools you might need
build:
os: ubuntu-24.04
tools:
python: "3.12"
# Build from the docs/ directory with Sphinx
sphinx:
configuration: docs/conf.py
# Optional but recommended, declare the Python requirements required
# to build your documentation
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
python:
install:
- requirements: docs/requirements.txt
================================================
FILE: LICENSE
================================================
MIT License
Copyright (c) 2018 M3, Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
================================================
FILE: README.md
================================================
# gokart
[](https://github.com/m3dev/gokart/actions?query=workflow%3ATest)
[](https://gokart.readthedocs.io/en/latest/)
[](https://pypi.org/project/gokart/)
[](https://pypi.org/project/gokart/)

Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline.
[Documentation](https://gokart.readthedocs.io/en/latest/) for the latest release is hosted on readthedocs.
# About gokart
Here are some good things about gokart.
- The following meta data for each Task is stored separately in a `pkl` file with hash value
- task output data
- imported all module versions
- task processing time
- random seed in task
- displayed log
- all parameters set as class variables in the task
- Automatically rerun the pipeline if parameters of Tasks are changed.
- Support GCS and S3 as a data store for intermediate results of Tasks in the pipeline.
- The above output is exchanged between tasks as an intermediate file, which is memory-friendly
- `pandas.DataFrame` type and column checking during I/O
- Directory structure of saved files is automatically determined from structure of script
- Seeds for numpy and random are automatically fixed
- Can code while adhering to [SOLID](https://en.wikipedia.org/wiki/SOLID) principles as much as possible
- Tasks are locked via redis even if they run in parallel
**All the functions above are created for constructing Machine Learning batches. Provides an excellent environment for reproducibility and team development.**
Here are some non-goal / downside of the gokart.
- Batch execution in parallel is supported, but not parallel and concurrent execution of tasks in memory.
- Gokart is focused on reproducibility. So, I/O and capacity of data storage can become a bottleneck.
- No support for task visualization.
- Gokart is not an experiment management tool. The management of the execution result is cut out as [Thunderbolt](https://github.com/m3dev/thunderbolt).
- Gokart does not recommend writing pipelines in toml, yaml, json, and more. Gokart is preferring to write them in Python.
# Getting Started
Within the activated Python environment, use the following command to install gokart.
```
pip install gokart
```
# Quickstart
## Minimal Example
A minimal gokart task looks something like this:
```python
import gokart
class Example(gokart.TaskOnKart):
def run(self):
self.dump('Hello, world!')
task = Example()
output = gokart.build(task)
print(output)
```
`gokart.build` returns the result dumped by `gokart.TaskOnKart`. The example will output the following.
```
Hello, world!
```
## Type-Safe Pipeline Example
We introduce type-annotations to make a gokart pipeline robust.
Please check the following example to see how to use type-annotations on gokart.
Before using this feature, ensure to enable [mypy plugin](https://gokart.readthedocs.io/en/latest/mypy_plugin.html) feature in your project.
```python
import gokart
# `gokart.TaskOnKart[str]` means that the task dumps `str`
class StrDumpTask(gokart.TaskOnKart[str]):
def run(self):
self.dump('Hello, world!')
# `gokart.TaskOnKart[int]` means that the task dumps `int`
class OneDumpTask(gokart.TaskOnKart[int]):
def run(self):
self.dump(1)
# `gokart.TaskOnKart[int]` means that the task dumps `int`
class TwoDumpTask(gokart.TaskOnKart[int]):
def run(self):
self.dump(2)
class AddTask(gokart.TaskOnKart[int]):
# `a` requires a task to dump `int`
a: gokart.TaskInstanceParameter[gokart.TaskOnKart[int]] = gokart.TaskInstanceParameter()
# `b` requires a task to dump `int`
b: gokart.TaskInstanceParameter[gokart.TaskOnKart[int]] = gokart.TaskInstanceParameter()
def requires(self):
return dict(a=self.a, b=self.b)
def run(self):
# loading by instance parameter,
# `a` and `b` are treated as `int`
# because they are declared as `gokart.TaskOnKart[int]`
a = self.load(self.a)
b = self.load(self.b)
self.dump(a + b)
valid_task = AddTask(a=OneDumpTask(), b=TwoDumpTask())
# the next line will show type error by mypy
# because `StrDumpTask` dumps `str` and `AddTask` requires `int`
invalid_task = AddTask(a=OneDumpTask(), b=StrDumpTask())
```
This is an introduction to some of the gokart.
There are still more useful features.
Please See [Documentation](https://gokart.readthedocs.io/en/latest/) .
Have a good gokart life.
# Achievements
Gokart is a proven product.
- It's actually been used by [m3.inc](https://corporate.m3.com/en) for over 3 years
- Natural Language Processing Competition by [Nishika.inc](https://nishika.com) 2nd prize : [Solution Repository](https://github.com/vaaaaanquish/nishika_akutagawa_2nd_prize)
# Thanks
gokart is a wrapper for luigi. Thanks to luigi and dependent projects!
- [luigi](https://github.com/spotify/luigi)
================================================
FILE: docs/Makefile
================================================
# Minimal makefile for Sphinx documentation
#
# You can set these variables from the command line.
SPHINXOPTS =
SPHINXBUILD = sphinx-build
SOURCEDIR = .
BUILDDIR = _build
# Put it first so that "make" without argument is like "make help".
help:
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
.PHONY: help Makefile
# Catch-all target: route all unknown targets to Sphinx using the new
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
%: Makefile
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
================================================
FILE: docs/conf.py
================================================
# https://github.com/sphinx-doc/sphinx/issues/6211
# Copy gokart's docstrings onto luigi's base-class methods so that Sphinx
# autodoc, which reads inherited members from luigi.task.Task, renders
# gokart's documentation for `requires` and `output` instead of luigi's.
import luigi

import gokart

luigi.task.Task.requires.__doc__ = gokart.task.TaskOnKart.requires.__doc__
luigi.task.Task.output.__doc__ = gokart.task.TaskOnKart.output.__doc__
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config

# -- Path setup --------------------------------------------------------------

# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
# import os
# import sys
# sys.path.insert(0, os.path.abspath('../gokart/'))

# -- Project information -----------------------------------------------------

project = 'gokart'
copyright = '2019, Masahiro Nishiba'
author = 'Masahiro Nishiba'

# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = ''

# -- General configuration ---------------------------------------------------

# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'

# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = ['sphinx.ext.autodoc', 'sphinx.ext.viewcode']

# Add any paths that contain templates here, relative to this directory.
templates_path = ['_templates']

# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.rst'

# The master toctree document.
master_doc = 'index'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
#
# NOTE: Sphinx >= 5 emits a warning when `language` is None and falls back
# to 'en'; set it explicitly to keep builds warning-free.
language = 'en'

# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = ['_build', 'Thumbs.db', '.DS_Store']

# The name of the Pygments (syntax highlighting) style to use.
pygments_style = None

# -- Options for HTML output -------------------------------------------------

# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}

# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = []

# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
# html_sidebars = {}

# -- Options for HTMLHelp output ---------------------------------------------

# Output file base name for HTML help builder.
htmlhelp_basename = 'gokartdoc'

# -- Options for LaTeX output ------------------------------------------------

latex_elements = {
    # The paper size ('letterpaper' or 'a4paper').
    #
    # 'papersize': 'letterpaper',
    # The font size ('10pt', '11pt' or '12pt').
    #
    # 'pointsize': '10pt',
    # Additional stuff for the LaTeX preamble.
    #
    # 'preamble': '',
    # Latex figure (float) alignment
    #
    # 'figure_align': 'htbp',
}

# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
#  author, documentclass [howto, manual, or own class]).
latex_documents = [
    (master_doc, 'gokart.tex', 'gokart Documentation', 'Masahiro Nishiba', 'manual'),
]

# -- Options for manual page output ------------------------------------------

# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [(master_doc, 'gokart', 'gokart Documentation', [author], 1)]

# -- Options for Texinfo output ----------------------------------------------

# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
#  dir menu entry, description, category)
texinfo_documents = [
    (master_doc, 'gokart', 'gokart Documentation', author, 'gokart', 'One line description of project.', 'Miscellaneous'),
]

# -- Options for Epub output -------------------------------------------------

# Bibliographic Dublin Core info.
epub_title = project

# The unique identifier of the text. This can be a ISBN number
# or the project homepage.
#
# epub_identifier = ''

# A unique identification for the text.
#
# epub_uid = ''

# A list of files that should not be packed into the epub file.
epub_exclude_files = ['search.html']
================================================
FILE: docs/efficient_run_on_multi_workers.rst
================================================
How to improve efficiency when running on multiple workers
===========================================================
If multiple worker nodes are running similar gokart pipelines in parallel, it is possible that the exact same task may be executed by multiple workers.
(For example, when training multiple machine learning models with different parameters, the feature creation task in the first stage is expected to be exactly the same.)
It is inefficient to execute the same task on each of multiple worker nodes, so we want to avoid this.
Here we introduce `should_lock_run` feature to improve this inefficiency.
Suppress run() of the same task with `should_lock_run`
------------------------------------------------------
When `gokart.TaskOnKart.should_lock_run` is set to True, the task will fail if the same task is run()-ing by another worker.
By failing the task, other tasks that can be executed at that time are given priority.
After that, the failed task is automatically re-executed.
.. code:: python
class SampleTask2(gokart.TaskOnKart):
should_lock_run = True
Additional Option
------------------
Skip completed tasks with `complete_check_at_run`
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
By setting `gokart.TaskOnKart.complete_check_at_run` to True, the existence of the cache can be rechecked at run() time.
Default is True, but if the check takes too much time, you can set it to False to deactivate the check.
.. code:: python
class SampleTask1(gokart.TaskOnKart):
complete_check_at_run = False
================================================
FILE: docs/for_pandas.rst
================================================
For Pandas
==========
Gokart has several features for Pandas.
Pandas Type Config
------------------
Pandas has a feature that converts the type of column(s) automatically. This feature sometimes causes wrong results. To avoid unintentional type conversion by pandas, we can specify a column name to check the type of Task input and output in gokart.
.. code:: python
from typing import Any, Dict
import pandas as pd
import gokart
# Please define a class which inherits `gokart.PandasTypeConfig`.
class SamplePandasTypeConfig(gokart.PandasTypeConfig):
@classmethod
def type_dict(cls) -> Dict[str, Any]:
return {'int_column': int}
class SampleTask(gokart.TaskOnKart[pd.DataFrame]):
def run(self):
# [PandasTypeError] because expected type is `int`, but `str` is passed.
df = pd.DataFrame(dict(int_column=['a']))
self.dump(df)
This is useful when dataframe has nullable columns because pandas auto-conversion often fails in such case.
Easy to Load DataFrame
----------------------
The :func:`~gokart.task.TaskOnKart.load` method is used to load input ``pandas.DataFrame``.
.. code:: python
def requires(self):
return MakeDataFrameTask()
def run(self):
df = self.load()
Please refer to :func:`~gokart.task.TaskOnKart.load`.
Fail on empty DataFrame
-----------------------
When the :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump` parameter is true, the :func:`~gokart.task.TaskOnKart.dump()` method raises :class:`~gokart.errors.EmptyDumpError` on trying to dump empty ``pandas.DataFrame``.
.. code:: python
import gokart
class EmptyTask(gokart.TaskOnKart):
def run(self):
df = pd.DataFrame()
self.dump(df)
::
$ python main.py EmptyTask --fail-on-empty-dump true
# EmptyDumpError
$ python main.py EmptyTask
# Task will run and output an empty dataframe
Empty caches sometimes hide bugs and let us spend much time debugging. This feature notifies us of some bugs (including wrong datasources) at an early stage.
Please refer to :attr:`~gokart.task.TaskOnKart.fail_on_empty_dump`.
================================================
FILE: docs/gokart.rst
================================================
gokart package
==============
Submodules
----------
gokart.file\_processor module
-----------------------------
.. automodule:: gokart.file_processor
:members:
:undoc-members:
:show-inheritance:
gokart.info module
------------------
.. automodule:: gokart.info
:members:
:undoc-members:
:show-inheritance:
gokart.parameter module
-----------------------
.. automodule:: gokart.parameter
:members:
:undoc-members:
:show-inheritance:
gokart.run module
-----------------
.. automodule:: gokart.run
:members:
:undoc-members:
:show-inheritance:
gokart.s3\_config module
------------------------
.. automodule:: gokart.s3_config
:members:
:undoc-members:
:show-inheritance:
gokart.target module
--------------------
.. automodule:: gokart.target
:members:
:undoc-members:
:show-inheritance:
gokart.task module
------------------
.. automodule:: gokart.task
:members:
:undoc-members:
:show-inheritance:
gokart.workspace\_management module
-----------------------------------
.. automodule:: gokart.workspace_management
:members:
:undoc-members:
:show-inheritance:
gokart.zip\_client module
-------------------------
.. automodule:: gokart.zip_client
:members:
:undoc-members:
:show-inheritance:
Module contents
---------------
.. automodule:: gokart
:members:
:undoc-members:
:show-inheritance:
================================================
FILE: docs/index.rst
================================================
.. gokart documentation master file, created by
sphinx-quickstart on Fri Jan 11 07:59:25 2019.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to gokart's documentation!
==================================
Useful links: `GitHub <https://github.com/m3dev/gokart>`_ | `cookiecutter gokart <https://github.com/m3dev/cookiecutter-gokart>`_
`Gokart <https://github.com/m3dev/gokart>`_ is a wrapper of the data pipeline library `luigi <https://github.com/spotify/luigi>`_. Gokart solves "**reproducibility**", "**task dependencies**", "**constraints of good code**", and "**ease of use**" for Machine Learning Pipeline.
Good thing about gokart
-----------------------
Here are some good things about gokart.
- The following data for each Task is stored separately in a pkl file with hash value
- task output data
- imported all module versions
- task processing time
- random seed in task
- displayed log
- all parameters set as class variables in the task
- If a parameter of a Task changes, the task is rerun automatically.
- The above file will be generated with a different hash value
- The hash value of dependent task will also change and both will be rerun
- Support GCS or S3
- The above output is exchanged between tasks as an intermediate file, which is memory-friendly
- pandas.DataFrame type and column checking during I/O
- Directory structure of saved files is automatically determined from structure of script
- Seeds for numpy and random are automatically fixed
- Can code while adhering to SOLID principles as much as possible
- Tasks are locked via redis even if they run in parallel
**All of these functions were designed for building Machine Learning batches. Gokart provides an excellent environment for reproducibility and team development.**
Getting started
-----------------
.. toctree::
:maxdepth: 2
intro_to_gokart
tutorial
User Guide
-----------------
.. toctree::
:maxdepth: 2
task_on_kart
task_parameters
setting_task_parameters
task_settings
task_information
logging
slack_notification
using_task_task_conflict_prevention_lock
efficient_run_on_multi_workers
for_pandas
polars
mypy_plugin
API References
--------------
.. toctree::
:maxdepth: 2
gokart
Indices and tables
-------------------
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
================================================
FILE: docs/intro_to_gokart.rst
================================================
Intro To Gokart
===============
Installation
------------
Within the activated Python environment, use the following command to install gokart.
.. code:: sh
pip install gokart
Quickstart
----------
A minimal gokart task looks something like this:
.. code:: python
import gokart
class Example(gokart.TaskOnKart[str]):
def run(self):
self.dump('Hello, world!')
task = Example()
output = gokart.build(task)
print(output)
``gokart.build`` returns the result dumped by ``gokart.TaskOnKart``. The example will output the following.
.. code:: sh
Hello, world!
``gokart`` records all the information needed for Machine Learning. By default, ``resources`` will be generated in the same directory as the script.
.. code:: sh
$ tree resources/
resources/
├── __main__
│ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
└── log
├── module_versions
│ └── Example_8441c59b5ce0113396d53509f19371fb.txt
├── processing_time
│ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
├── random_seed
│ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
├── task_log
│ └── Example_8441c59b5ce0113396d53509f19371fb.pkl
└── task_params
└── Example_8441c59b5ce0113396d53509f19371fb.pkl
The result of dumping the task will be saved in the ``__name__`` directory.
.. code:: python
import pickle
with open('resources/__main__/Example_8441c59b5ce0113396d53509f19371fb.pkl', 'rb') as f:
print(pickle.load(f)) # Hello, world!
The output file is given a hash value that depends on the parameters of the task. This means that if you change a parameter of the task, the hash value changes, and so does the output file. This is very useful when changing parameters and experimenting. Please refer to :doc:`task_parameters` section for task parameters. Also see :doc:`task_on_kart` section for information on how to configure this output destination.
In addition, the following files are automatically saved as ``log``.
- ``module_versions``: The versions of all modules that were imported when the script was executed. For reproducibility.
- ``processing_time``: The execution time of the task.
- ``random_seed``: This is random seed of python and numpy. For reproducibility in Machine Learning. Please refer to :doc:`task_settings` section.
- ``task_log``: This is the output of the task logger.
- ``task_params``: This is task's parameters. Please refer to :doc:`task_parameters` section.
How to run tasks
-------------------
Gokart has ``run`` and ``build`` methods for running tasks. Each has a different purpose.
- ``gokart.run``: uses arguments on the shell. return retcode.
- ``gokart.build``: uses inline code on jupyter notebook, IPython, and more. return task output.
.. note::
It is not recommended to use ``gokart.run`` and ``gokart.build`` together in the same script. Because ``gokart.build`` will clear the contents of ``luigi.register``. It's the only way to handle duplicate tasks.
gokart.run
~~~~~~~~~~
The :func:`~gokart.run` is running on shell.
.. code:: python
import gokart
import luigi
class SampleTask(gokart.TaskOnKart[str]):
param = luigi.Parameter()
def run(self):
self.dump(self.param)
gokart.run()
.. code:: sh
python sample.py SampleTask --local-scheduler --param=hello
If you were to write it in Python, it would be the same as the following behavior.
.. code:: python
gokart.run(['SampleTask', '--local-scheduler', '--param=hello'])
gokart.build
~~~~~~~~~~~~
The :func:`~gokart.build` is inline code.
.. code:: python
import gokart
import luigi
class SampleTask(gokart.TaskOnKart[str]):
param: luigi.Parameter = luigi.Parameter()
def run(self):
self.dump(self.param)
gokart.build(SampleTask(param='hello'), return_value=False)
To output logs of each tasks, you can pass `~log_level` parameter to `~gokart.build` as follows:
.. code:: python
gokart.build(SampleTask(param='hello'), return_value=False, log_level=logging.DEBUG)
This feature is very useful for running `~gokart` on jupyter notebook.
When some tasks are failed, gokart.build raises GokartBuildError. If you have to get tracebacks, you should set `log_level` as `logging.DEBUG`.
================================================
FILE: docs/logging.rst
================================================
Logging
=======
How to set up a common logger for gokart.
Core settings
-------------
Please write a configuration file similar to the following:
::
# base.ini
[core]
logging_conf_file=./conf/logging.ini
.. code:: python
import gokart
gokart.add_config('base.ini')
Logger ini file
---------------
It is the same as a general logging.ini file.
::
[loggers]
keys=root,luigi,luigi-interface,gokart,gokart.file_processor
[handlers]
keys=stderrHandler
[formatters]
keys=simpleFormatter
[logger_root]
level=INFO
handlers=stderrHandler
[logger_gokart]
level=INFO
handlers=stderrHandler
qualname=gokart
propagate=0
[logger_luigi]
level=INFO
handlers=stderrHandler
qualname=luigi
propagate=0
[logger_luigi-interface]
level=INFO
handlers=stderrHandler
qualname=luigi-interface
propagate=0
[logger_gokart.file_processor]
level=CRITICAL
handlers=stderrHandler
qualname=gokart.file_processor
[handler_stderrHandler]
class=StreamHandler
formatter=simpleFormatter
args=(sys.stdout,)
[formatter_simpleFormatter]
format=[%(asctime)s][%(name)s][%(levelname)s](%(filename)s:%(lineno)s) %(message)s
datefmt=%Y/%m/%d %H:%M:%S
Please refer to the `Python logging documentation <https://docs.python.org/3/library/logging.config.html>`_
================================================
FILE: docs/make.bat
================================================
@ECHO OFF
pushd %~dp0
REM Command file for Sphinx documentation
REM If SPHINXBUILD is not set, fall back to the `sphinx-build` found on PATH.
if "%SPHINXBUILD%" == "" (
set SPHINXBUILD=sphinx-build
)
set SOURCEDIR=.
set BUILDDIR=_build
REM No build target supplied: jump to the help text.
if "%1" == "" goto help
REM Probe that sphinx-build is runnable; errorlevel 9009 is the Windows
REM "command not found" code, so print installation instructions and abort.
%SPHINXBUILD% >NUL 2>NUL
if errorlevel 9009 (
echo.
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
echo.installed, then set the SPHINXBUILD environment variable to point
echo.to the full path of the 'sphinx-build' executable. Alternatively you
echo.may add the Sphinx directory to PATH.
echo.
echo.If you don't have Sphinx installed, grab it from
echo.http://sphinx-doc.org/
exit /b 1
)
REM Delegate the requested target to Sphinx's -M (make-mode) interface.
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
goto end
:help
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS%
:end
REM Restore the directory saved by pushd.
popd
================================================
FILE: docs/mypy_plugin.rst
================================================
[Experimental] Mypy plugin
===========================
Mypy plugin provides type checking for gokart tasks using Mypy.
This feature is experimental.
How to use
--------------
Configure Mypy to use this plugin by adding the following to your ``mypy.ini`` file:
.. code:: ini
[mypy]
plugins = gokart.mypy:plugin
or by adding the following to your ``pyproject.toml`` file:
.. code:: toml
[tool.mypy]
plugins = ["gokart.mypy"]
Then, run Mypy as usual.
Examples
--------
For example, the following code is linted by Mypy:
.. code:: python
import gokart
import luigi
class Foo(gokart.TaskOnKart):
# NOTE: must all the parameters be annotated
foo: int = luigi.IntParameter(default=1)
bar: str = luigi.Parameter()
Foo(foo=1, bar='2') # OK
Foo(foo='1') # NG because foo is not int and bar is missing
Mypy plugin checks TaskOnKart generic types.
.. code:: python
class SampleTask(gokart.TaskOnKart):
str_task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter()
int_task: gokart.TaskOnKart[int] = gokart.TaskInstanceParameter()
def requires(self):
return dict(str=self.str_task, int=self.int_task)
def run(self):
s = self.load(self.str_task) # This type is inferred with "str"
i = self.load(self.int_task) # This type is inferred with "int"
SampleTask(
str_task=StrTask(), # mypy ok
    int_task=StrTask(),  # mypy error: Argument "int_task" to "SampleTask" has incompatible type "StrTask"; expected "TaskOnKart[int]"
)
Configurations (only pyproject.toml)
------------------------------------
You can configure the Mypy plugin using the ``pyproject.toml`` file.
The following options are available:
.. code:: toml
[tool.gokart-mypy]
# If true, Mypy will raise an error if a task is missing required parameters.
# Note that this configuration can cause errors for parameters that are set via ``luigi.Config()``.
# Default: false
disallow_missing_parameters = true
================================================
FILE: docs/polars.rst
================================================
Polars Support
==============
Gokart supports Polars DataFrames alongside pandas DataFrames for DataFrame-based file processors. This allows gradual migration from pandas to Polars or using both libraries simultaneously in your data pipelines.
Installation
------------
Polars support is optional. Install it with:
.. code:: bash
pip install gokart[polars]
Or install Polars separately:
.. code:: bash
pip install polars
Basic Usage
-----------
To use Polars DataFrames with gokart, specify ``dataframe_type='polars'`` when creating file processors:
.. code:: python
import polars as pl
from gokart import TaskOnKart
from gokart.file_processor import FeatherFileProcessor
class MyPolarsTask(TaskOnKart[pl.DataFrame]):
def output(self):
return self.make_target(
'path/to/target.feather',
processor=FeatherFileProcessor(
store_index_in_feather=False,
dataframe_type='polars'
)
)
def run(self):
df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
self.dump(df)
Supported File Processors
--------------------------
The following file processors support the ``dataframe_type`` parameter:
CsvFileProcessor
^^^^^^^^^^^^^^^^
.. code:: python
from gokart.file_processor import CsvFileProcessor
# For Polars
processor = CsvFileProcessor(sep=',', encoding='utf-8', dataframe_type='polars')
# For pandas (default)
processor = CsvFileProcessor(sep=',', encoding='utf-8', dataframe_type='pandas')
# or simply
processor = CsvFileProcessor(sep=',', encoding='utf-8')
JsonFileProcessor
^^^^^^^^^^^^^^^^^
.. code:: python
from gokart.file_processor import JsonFileProcessor
# For Polars
processor = JsonFileProcessor(orient='records', dataframe_type='polars')
# For pandas (default)
processor = JsonFileProcessor(orient='records', dataframe_type='pandas')
ParquetFileProcessor
^^^^^^^^^^^^^^^^^^^^
.. code:: python
from gokart.file_processor import ParquetFileProcessor
# For Polars
processor = ParquetFileProcessor(
compression='gzip',
dataframe_type='polars'
)
# For pandas (default)
processor = ParquetFileProcessor(
compression='gzip',
dataframe_type='pandas'
)
FeatherFileProcessor
^^^^^^^^^^^^^^^^^^^^
.. code:: python
from gokart.file_processor import FeatherFileProcessor
# For Polars
processor = FeatherFileProcessor(
store_index_in_feather=False,
dataframe_type='polars'
)
# For pandas (default)
processor = FeatherFileProcessor(
store_index_in_feather=True,
dataframe_type='pandas'
)
.. note::
The ``store_index_in_feather`` parameter is pandas-specific and is ignored when using Polars.
Using Pandas and Polars Together
---------------------------------
Since projects often migrate from pandas gradually, gokart allows you to use both pandas and Polars simultaneously:
.. code:: python
import pandas as pd
import polars as pl
from gokart import TaskOnKart
from gokart.file_processor import FeatherFileProcessor
class PandasTask(TaskOnKart[pd.DataFrame]):
"""Task that outputs pandas DataFrame"""
def output(self):
return self.make_target(
'path/to/pandas_output.feather',
processor=FeatherFileProcessor(
store_index_in_feather=False,
dataframe_type='pandas'
)
)
def run(self):
df = pd.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
self.dump(df)
class PolarsTask(TaskOnKart[pl.DataFrame]):
"""Task that outputs Polars DataFrame"""
def requires(self):
return PandasTask()
def output(self):
return self.make_target(
'path/to/polars_output.feather',
processor=FeatherFileProcessor(
store_index_in_feather=False,
dataframe_type='polars'
)
)
def run(self):
# Load pandas DataFrame and convert to Polars
pandas_df = self.load() # Returns pandas DataFrame
polars_df = pl.from_pandas(pandas_df)
# Process with Polars
result = polars_df.with_columns(
(pl.col('a') * 2).alias('a_doubled')
)
self.dump(result)
Default Behavior
----------------
When ``dataframe_type`` is not specified, file processors default to ``'pandas'`` for backward compatibility:
.. code:: python
# These are equivalent
processor = CsvFileProcessor(sep=',')
processor = CsvFileProcessor(sep=',', dataframe_type='pandas')
Important Notes
---------------
**File Format Compatibility**
Files created with Polars processors can be read by pandas processors and vice versa. The underlying file formats (CSV, JSON, Parquet, Feather) are library-agnostic.
**Pandas-specific Features**
Some pandas-specific features are not available with Polars:
- ``store_index_in_feather`` parameter in ``FeatherFileProcessor`` is ignored for Polars
- ``engine`` parameter in ``ParquetFileProcessor`` is ignored for Polars (uses Polars' default)
**Error Handling**
If you specify ``dataframe_type='polars'`` but Polars is not installed, you'll get an ``ImportError`` with installation instructions:
.. code:: text
ImportError: polars is required for dataframe_type='polars'. Install with: pip install polars
Migration Strategy
------------------
Recommended approach for migrating from pandas to Polars:
1. Install Polars: ``pip install gokart[polars]``
2. Create new tasks using ``dataframe_type='polars'``
3. Keep existing tasks with ``dataframe_type='pandas'`` or default behavior
4. Gradually migrate tasks as needed
5. Convert DataFrames between libraries using ``pl.from_pandas()`` and ``df.to_pandas()`` when necessary
================================================
FILE: docs/requirements.txt
================================================
Sphinx
gokart
sphinx-rtd-theme
================================================
FILE: docs/setting_task_parameters.rst
================================================
============================
Setting Task Parameters
============================
There are several ways to set task parameters.
- Set parameter from command line
- Set parameter at config file
- Set parameter at upstream task
- Inherit parameter from other task
Set parameter from command line
==================================
.. code:: sh
python main.py sample.SomeTask --SomeTask-param=Hello
Parameter of each task can be set as a command line parameter in ``--[task name]-[parameter name]=[value]`` format.
Set parameter at config file
==================================
::
[sample.SomeTask]
param = Hello
Above config file (``config.ini``) must be read before ``gokart.run()`` as the following code:
.. code:: python
if __name__ == '__main__':
gokart.add_config('./conf/config.ini')
gokart.run()
It can also be loaded from environment variable as the following code:
::
[sample.SomeTask]
param=${PARAMS}
[TaskOnKart]
workspace_directory=${WORKSPACE_DIRECTORY}
The advantages of using environment variables are that 1) important information will not be logged, and 2) common settings can be shared across configurations.
Set parameter at upstream task
==================================
Parameters can be set at the upstream task, as in a typical pipeline.
.. code:: python
class UpstreamTask(gokart.TaskOnKart):
def requires(self):
return dict(sometask=SomeTask(param='Hello'))
Inherit parameter from other task
==================================
Parameter values can be inherited from other task using ``@inherits_config_params`` decorator.
.. code:: python
class MasterConfig(luigi.Config):
param: luigi.Parameter = luigi.Parameter()
param2: luigi.Parameter = luigi.Parameter()
@inherits_config_params(MasterConfig)
class SomeTask(gokart.TaskOnKart):
param: luigi.Parameter = luigi.Parameter()
This is useful when multiple tasks have the same parameter. In the above example, parameter settings of ``MasterConfig`` will be inherited by all tasks decorated with ``@inherits_config_params(MasterConfig)``, such as ``SomeTask``.
Note that only parameters which exist in both ``MasterConfig`` and ``SomeTask`` will be inherited.
In the above example, ``param2`` will not be available in ``SomeTask``, since ``SomeTask`` does not have ``param2`` parameter.
.. code:: python
class MasterConfig(luigi.Config):
param: luigi.Parameter = luigi.Parameter()
param2: luigi.Parameter = luigi.Parameter()
@inherits_config_params(MasterConfig, parameter_alias={'param2': 'param3'})
class SomeTask(gokart.TaskOnKart):
param3: luigi.Parameter = luigi.Parameter()
You may also set a parameter name alias by setting ``parameter_alias``.
``parameter_alias`` must be a dictionary of key: inheriting task's parameter name, value: decorating task's parameter name.
In the above example, ``SomeTask.param3`` will be set to same value as ``MasterConfig.param2``.
================================================
FILE: docs/slack_notification.rst
================================================
Slack notification
=========================
Prerequisites
-------------
Prepare following environmental variables:
.. code:: sh
export SLACK_TOKEN=xoxb-your-token // should use token starts with "xoxb-" (bot token is preferable)
export SLACK_CHANNEL=channel-name // not "#channel-name", just "channel-name"
A Slack bot token can be obtained from the `Slack app documentation `_.
A bot token needs following scopes:
- `channels:read`
- `chat:write`
- `files:write`
More about scopes can be found in the `Slack scopes documentation `_.
Implement Slack notification
----------------------------
Write the following code to pass arguments to your gokart workflow.
.. code:: python
cmdline_args = sys.argv[1:]
if 'SLACK_CHANNEL' in os.environ:
cmdline_args.append(f'--SlackConfig-channel={os.environ["SLACK_CHANNEL"]}')
if 'SLACK_TO_USER' in os.environ:
cmdline_args.append(f'--SlackConfig-to-user={os.environ["SLACK_TO_USER"]}')
gokart.run(cmdline_args)
================================================
FILE: docs/task_information.rst
================================================
Task Information
================
There are 6 ways to print the significant parameters and state of the task and its dependencies.
* 1. One is to use luigi module. See `luigi.tools.deps_tree module `_ for details.
* 2. ``task-info`` option of ``gokart.run()``.
* 3. ``make_task_info_as_tree_str()`` will return significant parameters and dependency tree as str.
* 4. ``make_task_info_as_table()`` will return significant parameter and dependent tasks as pandas.DataFrame table format.
* 5. ``dump_task_info_table()`` will dump the result of ``make_task_info_as_table()`` to a file.
* 6. ``dump_task_info_tree()`` will dump the task tree object (TaskInfo) to a pickle file.
This document will cover 2~6.
2. task-info option of gokart.run()
--------------------------------------------
On CLI
~~~~~~
An example implementation could be like:
.. code:: python
# main.py
import gokart
if __name__ == '__main__':
gokart.run()
.. code:: sh
$ python main.py \
TaskB \
--param=Hello \
--local-scheduler \
--tree-info-mode=all \
--tree-info-output-path=tree_all.txt
The ``--tree-info-mode`` option accepts "simple" and "all", and a task information is saved in ``--tree-info-output-path``.
when "simple" is passed, it outputs the states and the unique ids of tasks.
An example output is as follows:
.. code:: text
└─-(COMPLETE) TaskB[09fe5591ef2969ce7443c419a3b19e5d]
└─-(COMPLETE) TaskA[2549878535c070fb6c3cd4061bdbbcff]
When "all" is passed, it outputs the states, the unique ids, the significant parameters, the execution times and the task logs of tasks.
An example output is as follows:
.. code:: text
└─-(COMPLETE) TaskB[09fe5591ef2969ce7443c419a3b19e5d](parameter={'workspace_directory': './resources/', 'local_temporary_directory': './resources/tmp/', 'param': 'Hello'}, output=['./resources/output_of_task_b_09fe5591ef2969ce7443c419a3b19e5d.pkl'], time=0.002290010452270508s, task_log={})
└─-(COMPLETE) TaskA[2549878535c070fb6c3cd4061bdbbcff](parameter={'workspace_directory': './resources/', 'local_temporary_directory': './resources/tmp/', 'param': 'called by TaskB'}, output=['./resources/output_of_task_a_2549878535c070fb6c3cd4061bdbbcff.pkl'], time=0.0009829998016357422s, task_log={})
3. make_task_info_as_tree_str()
-----------------------------------------
``gokart.tree.task_info.make_task_info_as_tree_str()`` will return a task dependency tree as a str.
.. code:: python
from gokart.tree.task_info import make_task_info_as_tree_str
make_task_info_as_tree_str(task, ignore_task_names)
# Parameters
# ----------
# - task: TaskOnKart
# Root task.
# - details: bool
# Whether or not to output details.
# - abbr: bool
# Whether or not to simplify tasks information that has already appeared.
# - ignore_task_names: Optional[List[str]]
# List of task names to ignore.
# Returns
# -------
# - tree_info : str
# Formatted task dependency tree.
example
.. code:: python
import luigi
import gokart
class TaskA(gokart.TaskOnKart[str]):
param = luigi.Parameter()
def run(self):
self.dump(f'{self.param}')
class TaskB(gokart.TaskOnKart[str]):
task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter()
def run(self):
task = self.load('task')
self.dump(task + ' taskB')
class TaskC(gokart.TaskOnKart[str]):
task: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter()
def run(self):
task = self.load('task')
self.dump(task + ' taskC')
class TaskD(gokart.TaskOnKart):
task1: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter()
task2: gokart.TaskOnKart[str] = gokart.TaskInstanceParameter()
def run(self):
task = [self.load('task1'), self.load('task2')]
self.dump(','.join(task))
.. code:: python
task = TaskD(
task1=TaskD(
task1=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))), # same task
task2=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))) # same task
),
task2=TaskD(
task1=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))), # same task
task2=TaskD(task1=TaskC(task=TaskA(param='foo')), task2=TaskC(task=TaskB(task=TaskA(param='bar')))) # same task
)
)
print(gokart.make_task_info_as_tree_str(task))
.. code:: sh
└─-(PENDING) TaskD[187ff82158671283e127e2e1f7c9c095]
|--(PENDING) TaskD[ca9e943ce049e992b371898c0578784e] # duplicated TaskD
| |--(PENDING) TaskD[1cc9f9fc54a56614f3adef74398684f4] # duplicated TaskD
| | |--(PENDING) TaskC[dce3d8e7acaf1bb9731fb4f2ae94e473]
| | | └─-(PENDING) TaskA[be65508b556dd3752359b4246791413d]
| | └─-(PENDING) TaskC[de39593d31490aba3cdca3c650432504]
| | └─-(PENDING) TaskB[bc2f7d6cdd6521cc116c35f0f144eed3]
| | └─-(PENDING) TaskA[5a824f7d232eb69d46f0ac6bbd93b565]
| └─-(PENDING) TaskD[1cc9f9fc54a56614f3adef74398684f4]
| └─- ...
└─-(PENDING) TaskD[ca9e943ce049e992b371898c0578784e]
└─- ...
In the above example, the sub-trees that have already been shown are omitted.
This can be disabled by passing ``False`` to ``abbr`` flag:
.. code:: python
print(make_task_info_as_tree_str(task, abbr=False))
4. make_task_info_as_table()
--------------------------------
``gokart.tree.task_info.make_task_info_as_table()`` will return a table containing the information of significant parameters and dependent tasks as a pandas DataFrame.
This table contains `task name`, `cache unique id`, `cache file path`, `task parameters`, `task processing time`, `completed flag`, and `task log`.
.. code:: python
from gokart.tree.task_info import make_task_info_as_table
make_task_info_as_table(task, ignore_task_names)
# """Return a table containing information about dependent tasks.
#
# Parameters
# ----------
# - task: TaskOnKart
# Root task.
# - ignore_task_names: Optional[List[str]]
# List of task names to ignore.
# Returns
# -------
# - task_info_table : pandas.DataFrame
# Formatted task dependency table.
# """
5. dump_task_info_table()
-----------------------------------------
``gokart.tree.task_info.dump_task_info_table()`` will dump the task_info table made at ``make_task_info_as_table()`` to a file.
.. code:: python
from gokart.tree.task_info import dump_task_info_table
dump_task_info_table(task, task_info_dump_path, ignore_task_names)
# Parameters
# ----------
# - task: TaskOnKart
# Root task.
# - task_info_dump_path: str
# Output target file path. Path destination can be `local`, `S3`, or `GCS`.
# File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`.
# See `TaskOnKart.make_target module ` for details.
# - ignore_task_names: Optional[List[str]]
# List of task names to ignore.
# Returns
# -------
# None
6. dump_task_info_tree()
-----------------------------------------
``gokart.tree.task_info.dump_task_info_tree()`` will dump the task tree object (TaskInfo) to a pickle file.
.. code:: python
from gokart.tree.task_info import dump_task_info_tree
dump_task_info_tree(task, task_info_dump_path, ignore_task_names, use_unique_id)
# Parameters
# ----------
# - task: TaskOnKart
# Root task.
# - task_info_dump_path: str
# Output target file path. Path destination can be `local`, `S3`, or `GCS`.
# File extension must be '.pkl'.
# - ignore_task_names: Optional[List[str]]
# List of task names to ignore.
# - use_unique_id: bool = True
# Whether to use unique id to dump target file. Default is True.
# Returns
# -------
# None
Task Logs
---------
To output extra information of tasks by ``tree-info``, the member variable :attr:`~gokart.task.TaskOnKart.task_log` of ``TaskOnKart`` keeps any information as a dictionary.
For instance, the following code runs,
.. code:: python
import gokart
class SampleTaskLog(gokart.TaskOnKart):
def run(self):
# Add some logs.
self.task_log['sample key'] = 'sample value'
if __name__ == '__main__':
SampleTaskLog().run()
gokart.run([
'--tree-info-mode=all',
'--tree-info-output-path=sample_task_log.txt',
'SampleTaskLog',
'--local-scheduler'])
the output could be like:
.. code:: text
└─-(COMPLETE) SampleTaskLog[...](..., task_log={'sample key': 'sample value'})
Delete Unnecessary Output Files
--------------------------------
To delete output files which are not necessary to run a task, add option ``--delete-unnecessary-output-files``. This option is supported only when a task outputs files in local storage not S3 for now.
================================================
FILE: docs/task_on_kart.rst
================================================
TaskOnKart
==========
``TaskOnKart`` inherits ``luigi.Task``, and has functions to make it easy to define tasks.
Please see `luigi documentation `_ for details of ``luigi.Task``.
Please refer to :doc:`intro_to_gokart` section and :doc:`tutorial` section.
Outline
--------
How ``TaskOnKart`` helps to define a task looks like:
.. code:: python
import luigi
import gokart
class TaskA(gokart.TaskOnKart[str]):
param: luigi.Parameter = luigi.Parameter()
def output(self):
return self.make_target('output_of_task_a.pkl')
def run(self):
results = f'param={self.param}'
self.dump(results)
class TaskB(gokart.TaskOnKart[str]):
param: luigi.Parameter = luigi.Parameter()
def requires(self):
return TaskA(param='world')
def output(self):
# `make_target` makes an instance of `luigi.Target`.
# This infers the output format and the destination of an output objects.
# The target file path is
# '{self.workspace_directory}/output_of_task_b_{self.make_unique_id()}.pkl'.
return self.make_target('output_of_task_b.pkl')
def run(self):
# `load` loads input data. In this case, this loads the output of `TaskA`.
output_of_task_a = self.load()
results = f'Task A: {output_of_task_a}\nTaskB: param={self.param}'
# `dump` writes `results` to the file path of `self.output()`.
self.dump(results)
if __name__ == '__main__':
print(gokart.build([TaskB(param='Hello')]))
The result of this script will look like this
.. code:: sh
Task A: param=world
TaskB: param=Hello
The results are obtained as a pipeline by linking A and B.
TaskOnKart.make_target
----------------------
The :func:`~gokart.task.TaskOnKart.make_target` method is used to make an instance of ``Luigi.Target``.
For instance, an example implementation could be as follows:
.. code:: python
def output(self):
return self.make_target('file_name.pkl')
The ``make_target`` method adds ``_{self.make_unique_id()}`` to the file name as suffix.
In this case, the target file path is ``{self.workspace_directory}/file_name_{self.make_unique_id()}.pkl``.
It is also possible to specify a file format other than pkl. The supported file formats are as follows:
- .pkl
- .txt
- .csv
- .tsv
- .gz
- .json
- .xml
- .npz
- .parquet
- .feather
- .png
- .jpg
- .ini
If you dump something other than the above, you can use :func:`~gokart.TaskOnKart.make_model_target`.
Please refer to :func:`~gokart.task.TaskOnKart.make_target` and described later Advanced Features section.
.. note::
By default, file path is inferred from "__name__" of the script, so ``output`` method can be omitted.
Please refer to :doc:`tutorial` section.
.. note::
When using `.feather`, index will be converted to column at saving and restored to index at loading.
If you don't prefer saving the index, set the `store_index_in_feather=False` parameter at `gokart.target.make_target()`.
.. note::
When you set `serialized_task_definition_check=True`, the task will rerun when you modify the scripts of the task.
Please note that the scripts outside the class are not considered.
TaskOnKart.load
----------------
The :func:`~gokart.task.TaskOnKart.load` method is used to load input data.
For instance, an example implementation could be as follows:
.. code:: python
def requires(self):
return TaskA(param='called by TaskB')
def run(self):
# `load` loads input data. In this case, this loads the output of `TaskA`.
output_of_task_a = self.load()
In the case that a task requires 2 or more tasks as input, the return value of this method has the same structure with `requires` value.
For instance, an example implementation that `requires` returns a dictionary of tasks could be like follows:
.. code:: python
def requires(self):
return dict(a=TaskA(), b=TaskB())
def run(self):
data = self.load() # returns dict(a=self.load('a'), b=self.load('b'))
The `load` method loads individual task input by passing a key of an input dictionary as follows:
.. code:: python
def run(self):
data_a = self.load('a')
data_b = self.load('b')
As an alternative, the `load` method loads individual task input by passing an instance of TaskOnKart as follows:
.. code:: python
def run(self):
data_a = self.load(TaskA())
data_b = self.load(TaskB())
We can also omit the :func:`~gokart.task.TaskOnKart.requires` and write the task used by :func:`~gokart.parameter.TaskInstanceParameter`.
Also please refer to :func:`~gokart.task.TaskOnKart.load`, :doc:`task_parameters`, and described later Advanced Features section.
TaskOnKart.dump
----------------
The :func:`~gokart.task.TaskOnKart.dump` method is used to dump results of tasks.
For instance, an example implementation could be as follows:
.. code:: python
def output(self):
return self.make_target('output.pkl')
def run(self):
results = do_something(self.load())
self.dump(results)
In the case that a task has 2 or more output, it is possible to specify output target by passing a key of dictionary like follows:
.. code:: python
def output(self):
return dict(a=self.make_target('output_a.pkl'), b=self.make_target('output_b.pkl'))
def run(self):
a_data = do_something_a(self.load())
b_data = do_something_b(self.load())
self.dump(a_data, 'a')
self.dump(b_data, 'b')
Please refer to :func:`~gokart.task.TaskOnKart.dump`.
Advanced Features
---------------------
TaskOnKart.load_generator
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The :func:`~gokart.task.TaskOnKart.load_generator` method is used to load input data with generator.
For instance, an example implementation could be as follows:
.. code:: python
def requires(self):
return TaskA(param='called by TaskB')
def run(self):
for data in self.load_generator():
any_process(data)
Usage is the same as `TaskOnKart.load`.
`load_generator` reads the divided file into iterations.
It's effective when you can't read all the data into memory, because `load_generator` doesn't load all files at once.
Please refer to :func:`~gokart.task.TaskOnKart.load_generator`.
TaskOnKart.make_model_target
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
The :func:`~gokart.task.TaskOnKart.make_model_target` method is used to dump for non supported file types.
.. code:: python
import gensim
class TrainWord2Vec(gokart.TaskOnKart[Word2VecResult]):
def output(self):
# please use 'zip'.
return self.make_model_target(
'model.zip',
save_function=gensim.model.Word2Vec.save,
load_function=gensim.model.Word2Vec.load)
def run(self):
# -- train word2vec ---
word2vec = train_word2vec()
self.dump(word2vec)
It is dumped and zipped with ``gensim.model.Word2Vec.save``.
Please refer to :func:`~gokart.task.TaskOnKart.make_model_target`.
TaskOnKart.fail_on_empty_dump
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Please refer to :doc:`for_pandas`.
TaskOnKart.should_dump_supplementary_log_files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. Default is True.
Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.
Dump csv with encoding
~~~~~~~~~~~~~~~~~~~~~~~
You can dump csv file by implementing `Task.output()` method as follows:
.. code:: python
def output(self):
return self.make_target('file_name.csv')
By default, csv file is dumped with `utf-8` encoding.
If you want to dump csv file with other encodings, you can use `encoding` parameter as follows:
.. code:: python
from gokart.file_processor import CsvFileProcessor
def output(self):
return self.make_target('file_name.csv', processor=CsvFileProcessor(encoding='cp932'))
# This will dump csv as 'cp932' which is used in Windows.
Cache output in memory instead of dumping to files
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
You can use :class:`~InMemoryTarget` to cache output in memory instead of dumping to files by calling :func:`~gokart.target.make_in_memory_target`.
Please note that :class:`~InMemoryTarget` is an experimental feature.
.. code:: python
from gokart.in_memory.target import make_in_memory_target
def output(self):
unique_id = self.make_unique_id() if use_unique_id else None
# TaskLock is not supported in InMemoryTarget, so it's dummy
task_lock_params = make_task_lock_params(
file_path='dummy_path',
unique_id=unique_id,
redis_host=None,
redis_port=None,
redis_timeout=self.redis_timeout,
raise_task_lock_exception_on_collision=False,
)
return make_in_memory_target('dummy_path', task_lock_params, unique_id)
================================================
FILE: docs/task_parameters.rst
================================================
=================
Task Parameters
=================
Luigi Parameter
================
We can set parameters for tasks.
Also please refer to :doc:`task_settings` section.
.. code:: python
class Task(gokart.TaskOnKart):
param_a: luigi.Parameter = luigi.Parameter()
param_c: luigi.ListParameter = luigi.ListParameter()
param_d: luigi.IntParameter = luigi.IntParameter(default=1)
Please refer to `luigi document `_ for a list of parameter types.
Gokart Parameter
================
There are also parameters provided by gokart.
- gokart.TaskInstanceParameter
- gokart.ListTaskInstanceParameter
- gokart.ExplicitBoolParameter
gokart.TaskInstanceParameter
--------------------------------
The :func:`~gokart.parameter.TaskInstanceParameter` executes a task using the results of a task as dynamic parameters.
.. code:: python
class TaskA(gokart.TaskOnKart[str]):
def run(self):
self.dump('Hello')
class TaskB(gokart.TaskOnKart[str]):
require_task: gokart.TaskInstanceParameter = gokart.TaskInstanceParameter()
def requires(self):
return self.require_task
def run(self):
task_a = self.load()
self.dump(','.join([task_a, 'world']))
task = TaskB(require_task=TaskA())
print(gokart.build(task)) # Hello,world
Helps to create a pipeline.
gokart.ListTaskInstanceParameter
-------------------------------------
The :func:`~gokart.parameter.ListTaskInstanceParameter` is list of TaskInstanceParameter.
gokart.ExplicitBoolParameter
-----------------------------------
The :func:`~gokart.parameter.ExplicitBoolParameter` is parameter for explicitly specified value.
``luigi.BoolParameter`` already has an "explicit parsing" feature, but it also still has implicit behavior, as follows.
::
$ python main.py Task --param
# param will be set as True
$ python main.py Task
# param will be set as False
``ExplicitBoolParameter`` solves these problems on parameters from command line.
gokart.SerializableParameter
----------------------------
The :func:`~gokart.parameter.SerializableParameter` is a parameter for any object that can be serialized and deserialized.
This parameter is particularly useful when you want to pass a complex object or a set of parameters to a task.
The object must implement the following methods:
- ``gokart_serialize``: Serialize the object to a string. This serialized string must uniquely identify the object to enable task caching.
Note that it is not required for deserialization.
- ``gokart_deserialize``: Deserialize the object from a string, typically used for CLI arguments.
Example
^^^^^^^
.. code-block:: python
import json
from dataclasses import dataclass
import gokart
@dataclass(frozen=True)
class Config:
foo: int
# The `bar` field does not affect the result of the task.
# Similar to `luigi.Parameter(significant=False)`.
bar: str
def gokart_serialize(self) -> str:
# Serialize only the `foo` field since `bar` is irrelevant for caching.
return json.dumps({'foo': self.foo})
@classmethod
def gokart_deserialize(cls, s: str) -> 'Config':
# Deserialize the object from the provided string.
return cls(**json.loads(s))
class DummyTask(gokart.TaskOnKart):
config: gokart.SerializableParameter[Config] = gokart.SerializableParameter(object_type=Config)
def run(self):
# Save the `config` object as part of the task result.
self.dump(self.config)
================================================
FILE: docs/task_settings.rst
================================================
Task Settings
=============
Task settings. Also please refer to :doc:`task_parameters` section.
Directory to Save Outputs
-------------------------
We can use both a local directory and the S3 to save outputs.
If you would like to use local directory, please set a local directory path to :attr:`~gokart.task.TaskOnKart.workspace_directory`. Please refer to :doc:`task_parameters` for how to set it up.
It is recommended to use the config file since it does not change much.
::
# base.ini
[TaskOnKart]
workspace_directory=${TASK_WORKSPACE_DIRECTORY}
.. code:: python
# main.py
import gokart
gokart.add_config('base.ini')
To use the S3 or GCS repository, please set the bucket path as ``s3://{YOUR_REPOSITORY_NAME}`` or ``gs://{YOUR_REPOSITORY_NAME}`` respectively.
If use S3 or GCS, please set credential information to Environment Variables.
.. code:: sh
# S3
export AWS_ACCESS_KEY_ID='~~~' # AWS access key
export AWS_SECRET_ACCESS_KEY='~~~' # AWS secret access key
# GCS
export GCS_CREDENTIAL='~~~' # GCS credential
export DISCOVER_CACHE_LOCAL_PATH='~~~' # The local file path of discover api cache.
Rerun task
----------
There are times when we want to rerun a task, such as when change script or on batch. Please use the ``rerun`` parameter or add an arbitrary parameter.
When set rerun as follows:
.. code:: python
# rerun TaskA
gokart.build(Task(rerun=True))
When used from an argument as follows:
.. code:: python
# main.py
class Task(gokart.TaskOnKart[str]):
def run(self):
self.dump('hello')
.. code:: sh
python main.py Task --local-scheduler --rerun
``rerun`` parameter will look at the dependent tasks up to one level.
Example: Suppose we have a straight line pipeline composed of TaskA, TaskB and TaskC, and TaskC is an endpoint of this pipeline. We also suppose that all the tasks have already been executed.
- TaskA(rerun=True) -> TaskB -> TaskC # not rerunning
- TaskA -> TaskB(rerun=True) -> TaskC # rerunning TaskB and TaskC
This is due to the way intermediate files are handled. ``rerun`` parameter is ``significant=False``, it does not affect the hash value. It is very important to understand this difference.
If you want to change the parameter of TaskA and rerun TaskB and TaskC, recommend adding an arbitrary parameter.
.. code:: python
class TaskA(gokart.TaskOnKart):
__version: luigi.IntParameter = luigi.IntParameter(default=1)
If the hash value of TaskA will change, the dependent tasks (in this case, TaskB and TaskC) will rerun.
Fix random seed
---------------
Every task has a parameter named :attr:`~gokart.task.TaskOnKart.fix_random_seed_methods` and :attr:`~gokart.task.TaskOnKart.fix_random_seed_value`. This can be used to fix the random seed.
.. code:: python
import gokart
import random
import numpy as np
import torch
class Task(gokart.TaskOnKart[dict[str, Any]]):
def run(self):
x = [random.randint(0, 100) for _ in range(0, 10)]
y = [np.random.randint(0, 100) for _ in range(0, 10)]
z = [torch.randn(1).tolist()[0] for _ in range(0, 5)]
self.dump({'random': x, 'numpy': y, 'torch': z})
gokart.build(
Task(
fix_random_seed_methods=[
"random.seed",
"numpy.random.seed",
"torch.random.manual_seed"],
fix_random_seed_value=57))
::
# //--- The output is as follows every time. ---
# {'random': [65, 41, 61, 37, 55, 81, 48, 2, 94, 21],
#  'numpy': [79, 86, 5, 22, 79, 98, 56, 40, 81, 37],
#  'torch': [0.14460121095180511, -0.11649507284164429,
#            0.6928958296775818, -0.916053831577301, 0.7317505478858948]}
This will be useful for using Machine Learning Libraries.
================================================
FILE: docs/tutorial.rst
================================================
Tutorial
========
Also please refer to :doc:`intro_to_gokart` section.
1, Make gokart project
----------------------
Create a project using `cookiecutter-gokart <https://github.com/m3dev/cookiecutter-gokart>`_.
.. code:: sh
cookiecutter https://github.com/m3dev/cookiecutter-gokart
# project_name [project_name]: example
# package_name [package_name]: gokart_example
# python_version [3.7.0]:
# author [your name]: m3dev
# package_description [What's this project?]: gokart example
# license [MIT License]:
You will have a directory tree like following:
.. code:: sh
tree example/
example/
├── Dockerfile
├── README.md
├── conf
│ ├── logging.ini
│ └── param.ini
├── gokart_example
│ ├── __init__.py
│ ├── model
│ │ ├── __init__.py
│ │ └── sample.py
│ └── utils
│ └── template.py
├── main.py
├── pyproject.toml
└── test
├── __init__.py
└── unit_test
└── test_sample.py
2, Running sample task
----------------------
Let's run the first task.
.. code:: sh
python main.py gokart_example.Sample --local-scheduler
The results are stored in resources directory.
.. code:: sh
tree resources
resources/
├── gokart_example
│ └── model
│ └── sample
│ └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl
└── log
├── module_versions
│ └── Sample_cdf55a3d6c255d8c191f5f472da61f99.txt
├── processing_time
│ └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl
├── random_seed
│ └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl
├── task_log
│ └── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl
└── task_params
└── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl
Please refer to :doc:`intro_to_gokart` for output
.. note::
It is better to use poetry in terms of the module version. Please refer to the `poetry documentation <https://python-poetry.org/docs/>`_.
.. code:: sh
poetry lock
poetry run python main.py gokart_example.Sample --local-scheduler
If want to stabilize it further, please use docker.
.. code:: sh
docker build -t sample .
docker run -it sample "python main.py gokart_example.Sample --local-scheduler"
3, Check result
---------------
Check the output.
.. code:: python
with open('resources/gokart_example/model/sample/Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl', 'rb') as f:
print(pickle.load(f)) # sample output
4, Run unittest
------------------
It is important to run unittest before and after modifying the code.
.. code:: sh
python -m unittest discover -s ./test/unit_test/
.
----------------------------------------------------------------------
Ran 1 test in 0.001s
OK
5, Create Task
--------------
Writing gokart-like tasks.
Modify ``example/gokart_example/model/sample.py`` as follows:
.. code:: python
from logging import getLogger
import gokart
from gokart_example.utils.template import GokartTask
logger = getLogger(__name__)
class Sample(GokartTask):
def run(self):
self.dump('sample output')
class StringToSplit(GokartTask):
"""Like the function to divide received data by spaces."""
task: gokart.TaskInstanceParameter = gokart.TaskInstanceParameter()
def run(self):
sample = self.load('task')
self.dump(sample.split(' '))
class Main(GokartTask):
"""Endpoint task."""
def requires(self):
return StringToSplit(task=Sample())
Added ``Main`` and ``StringToSplit``. ``StringToSplit`` is a function-like task that loads the result of an arbitrary task and splits it by spaces. ``Main`` is injecting ``Sample`` into ``StringToSplit``. It acts like an endpoint.
Let’s run the ``Main`` task.
.. code:: sh
python main.py gokart_example.Main --local-scheduler
Please take a look at the logger output at this time.
::
===== Luigi Execution Summary =====
Scheduled 3 tasks of which:
* 1 complete ones were encountered:
- 1 gokart_example.Sample(...)
* 2 ran successfully:
- 1 gokart_example.Main(...)
- 1 gokart_example.StringToSplit(...)
This progress looks :) because there were no failed tasks or missing dependencies
===== Luigi Execution Summary =====
As the log shows, ``Sample`` has been executed once, so the ``cache`` will be used.
The only things that worked were ``Main`` and ``StringToSplit``.
The output will look like the following, with the result in ``StringToSplit_b8a0ce6c972acbd77eae30f35da4307e.pkl``.
::
tree resources/
resources/
├── gokart_example
│ └── model
│ └── sample
│ ├── Sample_cdf55a3d6c255d8c191f5f472da61f99.pkl
│ └── StringToSplit_b8a0ce6c972acbd77eae30f35da4307e.pkl
...
.. code:: python
with open('resources/gokart_example/model/sample/StringToSplit_b8a0ce6c972acbd77eae30f35da4307e.pkl', 'rb') as f:
print(pickle.load(f)) # ['sample', 'output']
It was able to move the added task.
6, Rerun Task
-------------
Finally, let's rerun the task.
There are two ways to rerun a task.
Change the ``rerun parameter`` or ``parameters of the dependent tasks``.
``gokart.TaskOnKart`` can set ``rerun parameter`` for each task like following:
.. code:: python
class Main(GokartTask):
rerun=True
def requires(self):
return StringToSplit(task=Sample(rerun=True), rerun=True)
OR
Add new parameter on dependent tasks like following:
.. code:: python
class Sample(GokartTask):
version: luigi.IntParameter = luigi.IntParameter(default=1)
def run(self):
self.dump(f'sample output version {self.version}')
In both cases, all tasks will be rerun.
The difference is hash value given to output files.
The rerun parameter has no effect on the hash value.
So it will be rerun with the same hash value.
In the second method, ``version parameter`` is added to the ``Sample`` task.
This parameter will change the hash value of ``Sample`` and generate another output file.
And the dependent task, ``StringToSplit``, will also have a different hash value, and rerun.
Please refer to :doc:`task_settings` for details.
Please try rerunning task at hand:)
Feature
-------
This is the end of the gokart tutorial.
The tutorial is an introduction to some of the features.
There are still more useful features.
Please See :doc:`task_on_kart` section, :doc:`for_pandas` section and :doc:`task_parameters` section for more useful features of the task.
Have a good gokart life.
================================================
FILE: docs/using_task_task_conflict_prevention_lock.rst
================================================
Task conflict prevention lock
=============================
If there is a possibility of multiple worker nodes executing the same task, task cache conflict may happen.
Specifically, while node A is loading the cache of a task, node B may be writing to it.
This can lead to reading an inappropriate data and other unwanted behaviors.
The redis lock introduced in this page is a feature to prevent such cache collisions.
Requires
--------
You need to install `redis <https://redis.io/>`_ for using this advanced feature.
How to use
-----------
1. Set up a redis server at somewhere accessible from gokart/luigi jobs.
e.g. Following script will run redis at your localhost.
.. code:: bash
$ redis-server
2. Set redis server hostname and port number as parameters of gokart.TaskOnKart().
You can set it by adding ``--redis-host=[your-redis-localhost] --redis-port=[redis-port-number]`` options to gokart python script.
e.g.
.. code:: bash
python main.py sample.SomeTask --local-scheduler --redis-host=localhost --redis-port=6379
Alternatively, you may set parameters at config file.
e.g.
.. code::
[TaskOnKart]
redis_host=localhost
redis_port=6379
3. Done
With the above configuration, every task that inherits gokart.TaskOnKart will, whenever it accesses a cache file with dump or load, ask the redis server whether another node is accessing the same cache file at the same time.
================================================
FILE: examples/gokart_notebook_example.ipynb
================================================
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"gokart 1.0.2\n"
]
}
],
"source": [
"!pip list | grep gokart"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"import gokart\n",
"import luigi"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Examples of using gokart at jupyter notebook"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Basic Usage\n",
"This is a very basic usage, just to dump a run result of ExampleTaskA."
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"example_2\n"
]
}
],
"source": [
"class ExampleTaskA(gokart.TaskOnKart):\n",
" param = luigi.Parameter()\n",
" int_param = luigi.IntParameter(default=2)\n",
"\n",
" def run(self):\n",
" self.dump(f'DONE {self.param}_{self.int_param}')\n",
"\n",
" \n",
"task_a = ExampleTaskA(param='example')\n",
"output = gokart.build(task=task_a)\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make tasks dependencies with `requires()`\n",
"ExampleTaskB is dependent on ExampleTaskC and ExampleTaskD. They are defined in `ExampleTaskB.requires()`."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DONE example_TASKC_TASKD\n"
]
}
],
"source": [
"class ExampleTaskC(gokart.TaskOnKart):\n",
" def run(self):\n",
" self.dump('TASKC')\n",
" \n",
"class ExampleTaskD(gokart.TaskOnKart):\n",
" def run(self):\n",
" self.dump('TASKD')\n",
"\n",
"class ExampleTaskB(gokart.TaskOnKart):\n",
" param = luigi.Parameter()\n",
"\n",
" def requires(self):\n",
" return dict(task_c=ExampleTaskC(), task_d=ExampleTaskD())\n",
"\n",
" def run(self):\n",
" task_c = self.load('task_c')\n",
" task_d = self.load('task_d')\n",
" self.dump(f'DONE {self.param}_{task_c}_{task_d}')\n",
" \n",
"task_b = ExampleTaskB(param='example')\n",
"output = gokart.build(task=task_b)\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Make tasks dependencies with TaskInstanceParameter\n",
"The dependencies are the same as in the previous example; however, they are defined outside of the task instead of being defined at `ExampleTaskB.requires()`."
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"DONE example_TASKC_TASKD\n"
]
}
],
"source": [
"class ExampleTaskC(gokart.TaskOnKart):\n",
" def run(self):\n",
" self.dump('TASKC')\n",
" \n",
"class ExampleTaskD(gokart.TaskOnKart):\n",
" def run(self):\n",
" self.dump('TASKD')\n",
"\n",
"class ExampleTaskB(gokart.TaskOnKart):\n",
" param = luigi.Parameter()\n",
" task_1 = gokart.TaskInstanceParameter()\n",
" task_2 = gokart.TaskInstanceParameter()\n",
"\n",
" def requires(self):\n",
" return dict(task_1=self.task_1, task_2=self.task_2) # required tasks are decided from the task parameters `task_1` and `task_2`\n",
"\n",
" def run(self):\n",
" task_1 = self.load('task_1')\n",
" task_2 = self.load('task_2')\n",
" self.dump(f'DONE {self.param}_{task_1}_{task_2}')\n",
" \n",
"task_b = ExampleTaskB(param='example', task_1=ExampleTaskC(), task_2=ExampleTaskD()) # Dependent tasks are defined here\n",
"output = gokart.build(task=task_b)\n",
"print(output)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.8.8 64-bit ('3.8.8': pyenv)",
"name": "python388jvsc74a57bd026997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
},
"metadata": {
"interpreter": {
"hash": "26997db2bf0f03e18da4e606f276befe0d6bf7cab2a6bb74742969d5bbde02ca"
}
},
"orig_nbformat": 3
},
"nbformat": 4,
"nbformat_minor": 2
}
================================================
FILE: examples/logging.ini
================================================
[loggers]
keys=root,luigi,luigi-interface,gokart
[handlers]
keys=stderrHandler
[formatters]
keys=simpleFormatter
[logger_root]
level=INFO
handlers=stderrHandler
[logger_gokart]
level=INFO
handlers=stderrHandler
qualname=gokart
propagate=0
[logger_luigi]
level=INFO
handlers=stderrHandler
qualname=luigi
propagate=0
[logger_luigi-interface]
level=INFO
handlers=stderrHandler
qualname=luigi-interface
propagate=0
[handler_stderrHandler]
class=StreamHandler
formatter=simpleFormatter
args=(sys.stdout,)
[formatter_simpleFormatter]
format=level=%(levelname)s time=%(asctime)s name=%(name)s file=%(filename)s line=%(lineno)d message=%(message)s
datefmt=%Y/%m/%d %H:%M:%S
class=logging.Formatter
================================================
FILE: examples/param.ini
================================================
[TaskOnKart]
workspace_directory=./resource
local_temporary_directory=./resource/tmp
[core]
logging_conf_file=logging.ini
================================================
FILE: gokart/__init__.py
================================================
__all__ = [
'build',
'WorkerSchedulerFactory',
'make_tree_info',
'tree_info',
'PandasTypeConfig',
'ExplicitBoolParameter',
'ListTaskInstanceParameter',
'SerializableParameter',
'TaskInstanceParameter',
'ZonedDateSecondParameter',
'run',
'TaskOnKart',
'test_run',
'make_task_info_as_tree_str',
'add_config',
'delete_local_unnecessary_outputs',
]
from gokart.build import WorkerSchedulerFactory, build
from gokart.info import make_tree_info, tree_info
from gokart.pandas_type_config import PandasTypeConfig
from gokart.parameter import (
ExplicitBoolParameter,
ListTaskInstanceParameter,
SerializableParameter,
TaskInstanceParameter,
ZonedDateSecondParameter,
)
from gokart.run import run
from gokart.task import TaskOnKart
from gokart.testing import test_run
from gokart.tree.task_info import make_task_info_as_tree_str
from gokart.utils import add_config
from gokart.workspace_management import delete_local_unnecessary_outputs
================================================
FILE: gokart/build.py
================================================
from __future__ import annotations
import enum
import io
import logging
from dataclasses import dataclass
from functools import partial
from logging import getLogger
from typing import Any, Literal, Protocol, TypeVar, cast, overload
import backoff
import luigi
from luigi import LuigiStatusCode, rpc, scheduler
import gokart
import gokart.tree.task_info
from gokart import worker
from gokart.conflict_prevention_lock.task_lock import TaskLockException
from gokart.target import TargetOnKart
from gokart.task import TaskOnKart
T = TypeVar('T')
logger: logging.Logger = logging.getLogger(__name__)
class LoggerConfig:
    """Context manager that temporarily raises the logging threshold.

    On entry every logger below ``level`` is disabled and this module's logger is
    set to ``level``; on exit both are restored to the values seen at construction.
    """

    def __init__(self, level: int):
        self.logger = getLogger(__name__)
        self.default_level = self.logger.level
        self.level = level

    def _apply(self, level: int) -> None:
        # logging.disable mutes records *at or below* the given value, so subtract
        # 10 (one level step) to keep records at exactly `level` flowing through.
        logging.disable(level - 10)
        self.logger.setLevel(level)

    def __enter__(self):
        self._apply(self.level)
        return self

    def __exit__(self, exception_type, exception_value, traceback):
        self._apply(self.default_level)
class GokartBuildError(Exception):
    """Raised when ``gokart.build`` failed. This exception contains raised exceptions in the task execution."""

    def __init__(self, message: str, raised_exceptions: dict[str, list[Exception]]) -> None:
        # Map of task unique_id -> exceptions raised while running that task.
        self.raised_exceptions = raised_exceptions
        super().__init__(message)
class HasLockedTaskException(Exception):
    """Raised when the task failed to acquire the lock in the task execution.

    ``build`` retries the whole run with exponential backoff when this is raised.
    """
class TaskLockExceptionRaisedFlag:
    """Mutable boolean box recording whether a ``TaskLockException`` occurred during a run."""

    def __init__(self) -> None:
        # Starts unset; the FAILURE event handler inside `build` flips it to True.
        self.flag: bool = False
class WorkerProtocol(Protocol):
    """Protocol for Worker.
    This protocol is determined by luigi.worker.Worker.
    """

    # Schedule `task` on the worker; boolean result mirrors luigi.worker.Worker.add.
    def add(self, task: TaskOnKart[Any]) -> bool: ...

    # Execute the scheduled work; boolean result mirrors luigi.worker.Worker.run.
    def run(self) -> bool: ...

    def __enter__(self) -> WorkerProtocol: ...

    # Literal[False] return means exceptions are never suppressed by the context manager.
    def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]: ...
class WorkerSchedulerFactory:
    """Factory producing the scheduler and worker objects that ``gokart.build`` hands to ``luigi.build``."""

    def create_local_scheduler(self) -> scheduler.Scheduler:
        # In-process scheduler with task history disabled and pruning on get_work enabled.
        return scheduler.Scheduler(prune_on_get_work=True, record_task_history=False)

    def create_remote_scheduler(self, url: str) -> rpc.RemoteScheduler:
        return rpc.RemoteScheduler(url)

    # NOTE: the `scheduler` parameter shadows the imported `scheduler` module inside this method.
    def create_worker(self, scheduler: scheduler.Scheduler, worker_processes: int, assistant: bool = False) -> WorkerProtocol:
        # Uses gokart's own Worker (gokart.worker), not luigi's.
        return worker.Worker(scheduler=scheduler, worker_processes=worker_processes, assistant=assistant)
def _get_output(task: TaskOnKart[T]) -> T:
    """Load and return the materialized output(s) of ``task``.

    Supports a single target, or one flat list/tuple/dict of targets (a tuple of
    targets is returned as a list, matching the original behavior). Non-target
    entries inside containers are silently skipped.

    Raises:
        ValueError: when ``task.output()`` is none of the supported shapes.
    """
    output = task.output()
    # FIXME: currently, nested output is not supported
    if isinstance(output, (list, tuple)):
        return cast(T, [t.load() for t in output if isinstance(t, TargetOnKart)])
    if isinstance(output, dict):
        return cast(T, {k: t.load() for k, t in output.items() if isinstance(t, TargetOnKart)})
    if isinstance(output, TargetOnKart):
        return cast(T, output.load())
    raise ValueError(f'output type is not supported: {type(output)}')
def _reset_register(keep=frozenset({'gokart', 'luigi'})):
    """Reset ``luigi.task_register.Register._reg`` every time ``gokart.build`` is called
    to avoid ``TaskClassAmbigiousException``.

    Args:
        keep: top-level module names whose registered task classes survive the reset.
            A frozenset (not a mutable set) is used as the default to avoid the
            shared-mutable-default pitfall; lookups behave identically.
    """
    luigi.task_register.Register._reg = [
        x
        for x in luigi.task_register.Register._reg
        if (
            (x.__module__.split('.')[0] in keep)  # keep luigi and gokart
            or (issubclass(x, gokart.PandasTypeConfig))
        )  # PandasTypeConfig should be kept
    ]
class TaskDumpMode(enum.Enum):
    """How task dependency information is rendered: as a tree, as a table, or not at all."""

    TREE = 'tree'
    TABLE = 'table'
    NONE = 'none'
class TaskDumpOutputType(enum.Enum):
    """Where rendered task information goes: logged (PRINT), written to a target (DUMP), or nowhere."""

    PRINT = 'print'
    DUMP = 'dump'
    NONE = 'none'
@dataclass
class TaskDumpConfig:
    """Pair of (mode, output type) consumed by ``process_task_info``; defaults to doing nothing."""

    # Rendering format of the task dependency information.
    mode: TaskDumpMode = TaskDumpMode.NONE
    # Destination of the rendered information.
    output_type: TaskDumpOutputType = TaskDumpOutputType.NONE
def process_task_info(task: TaskOnKart[Any], task_dump_config: TaskDumpConfig = TaskDumpConfig()) -> None:
    """Emit dependency information of ``task`` in the format selected by ``task_dump_config``.

    TREE renders ``gokart.make_tree_info``; TABLE renders a tab-separated table.
    PRINT logs the result at INFO level; DUMP writes it under ``log/task_info/``
    in the workspace (``.txt`` for trees, ``.pkl`` for tables).

    Raises:
        ValueError: for (mode, output_type) combinations not covered by the cases below.

    NOTE(review): the ``TaskDumpConfig()`` default is a shared instance; it is never
    mutated here, so the mutable-default pattern is harmless.
    """
    match task_dump_config:
        case TaskDumpConfig(mode=TaskDumpMode.NONE, output_type=TaskDumpOutputType.NONE):
            pass
        case TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.PRINT):
            tree = gokart.make_tree_info(task)
            logger.info(tree)
        case TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.PRINT):
            table = gokart.tree.task_info.make_task_info_as_table(task)
            output = io.StringIO()
            table.to_csv(output, index=False, sep='\t')
            output.seek(0)
            logger.info(output.read())
        case TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.DUMP):
            tree = gokart.make_tree_info(task)
            gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.txt').dump(tree)
        case TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.DUMP):
            table = gokart.tree.task_info.make_task_info_as_table(task)
            gokart.TaskOnKart().make_target(f'log/task_info/{type(task).__name__}.pkl').dump(table)
        case _:
            # e.g. TREE+NONE or NONE+PRINT — rejected explicitly rather than silently ignored.
            raise ValueError(f'Unsupported TaskDumpConfig: {task_dump_config}')
# Overload: with return_value=True (the default) `build` returns the task's loaded output.
@overload
def build(
    task: TaskOnKart[T],
    return_value: Literal[True] = True,
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    **env_params: Any,
) -> T: ...


# Overload: with return_value=False `build` returns None.
@overload
def build(
    task: TaskOnKart[T],
    return_value: Literal[False],
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    **env_params: Any,
) -> None: ...
def build(
    task: TaskOnKart[T],
    return_value: bool = True,
    reset_register: bool = True,
    log_level: int = logging.ERROR,
    task_lock_exception_max_tries: int = 10,
    task_lock_exception_max_wait_seconds: int = 600,
    task_dump_config: TaskDumpConfig = TaskDumpConfig(),
    **env_params: Any,
) -> T | None:
    """
    Run gokart task for local interpreter.
    Sharing the most of its parameters with luigi.build (see https://luigi.readthedocs.io/en/stable/api/luigi.html?highlight=build#luigi.build)

    Args:
        task: endpoint task to run.
        return_value: when True, load and return the task's output after a successful run.
        reset_register: when True, prune luigi's task register first (see ``_reset_register``).
        log_level: logging level applied while the pipeline runs.
        task_lock_exception_max_tries: max attempts when the run failed only because a task lock was held.
        task_lock_exception_max_wait_seconds: cap of the exponential backoff between those attempts.
        task_dump_config: how/where to emit the task dependency information before running.
        env_params: forwarded to ``luigi.build``.

    Raises:
        GokartBuildError: when the luigi run ends in a failed/scheduling-failed status;
            carries the per-task exceptions collected during the run.
    """
    if reset_register:
        _reset_register()
    with LoggerConfig(level=log_level):
        # Temporary stream handler so the pre-run task info is visible before luigi configures logging.
        log_handler_before_run = logging.StreamHandler()
        logger.addHandler(log_handler_before_run)
        process_task_info(task, task_dump_config)
        logger.removeHandler(log_handler_before_run)
        log_handler_before_run.close()
        task_lock_exception_raised = TaskLockExceptionRaisedFlag()
        # unique_id -> exceptions raised by that task; surfaced via GokartBuildError.
        raised_exceptions: dict[str, list[Exception]] = dict()

        # NOTE: `task` here is the *failing* task, shadowing the outer endpoint `task`.
        @TaskOnKart.event_handler(luigi.Event.FAILURE)
        def when_failure(task, exception):
            if isinstance(exception, TaskLockException):
                # A held lock is not a real failure; the flag triggers a backoff retry below.
                task_lock_exception_raised.flag = True
            else:
                raised_exceptions.setdefault(task.make_unique_id(), []).append(exception)

        @backoff.on_exception(
            partial(backoff.expo, max_value=task_lock_exception_max_wait_seconds), HasLockedTaskException, max_tries=task_lock_exception_max_tries
        )
        def _build_task():
            task_lock_exception_raised.flag = False
            result = luigi.build(
                [task],
                worker_scheduler_factory=WorkerSchedulerFactory(),
                local_scheduler=True,
                detailed_summary=True,
                log_level=logging.getLevelName(log_level),
                **env_params,
            )
            if task_lock_exception_raised.flag:
                raise HasLockedTaskException()
            if result.status in (LuigiStatusCode.FAILED, LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED, LuigiStatusCode.SCHEDULING_FAILED):
                raise GokartBuildError(result.summary_text, raised_exceptions=raised_exceptions)
            return _get_output(task) if return_value else None

        return cast(T | None, _build_task())
================================================
FILE: gokart/config_params.py
================================================
from __future__ import annotations
from typing import Any
import luigi
import gokart
class inherits_config_params:
    """Class decorator that copies parameter values from a ``luigi.Config`` subclass onto a decorated task."""

    def __init__(self, config_class: type[luigi.Config], parameter_alias: dict[str, str] | None = None):
        """
        Decorates task to inherit parameter value of `config_class`.

        * config_class: Inherit parameter value of this task to decorated task. Only parameter values exist in both tasks are inherited.
        * parameter_alias: Dictionary to map parameter names between config_class task and decorated task.
            key: config_class's parameter name. value: decorated task's parameter name.
        """
        self._config_class: type[luigi.Config] = config_class
        self._parameter_alias: dict[str, str] = parameter_alias if parameter_alias is not None else {}

    def __call__(self, task_class: type[gokart.TaskOnKart[Any]]) -> type[gokart.TaskOnKart[Any]]:
        # wrap task to prevent task name from being changed
        @luigi.task._task_wraps(task_class)
        class Wrapped(task_class):  # type: ignore
            @classmethod
            def get_param_values(cls, params, args, kwargs):
                # Inject each config value unless the caller already supplied that parameter,
                # translating names through the alias mapping where one exists.
                for param_key, param_value in self._config_class().param_kwargs.items():
                    task_param_key = self._parameter_alias.get(param_key, param_key)
                    if hasattr(cls, task_param_key) and task_param_key not in kwargs:
                        kwargs[task_param_key] = param_value
                return super().get_param_values(params, args, kwargs)

        return Wrapped
================================================
FILE: gokart/conflict_prevention_lock/task_lock.py
================================================
from __future__ import annotations
import functools
import os
from logging import getLogger
from typing import Any, NamedTuple
import redis
from apscheduler.schedulers.background import BackgroundScheduler
logger = getLogger(__name__)
class TaskLockParams(NamedTuple):
    """Immutable bundle of settings for redis-based task cache locking."""

    redis_host: str | None
    redis_port: int | None
    # Lock TTL in seconds; None means no expiry.
    redis_timeout: int | None
    # Name of the redis lock (derived from the cache file name and the task unique id).
    redis_key: str
    # Locking is enabled only when both redis_host and redis_port are set.
    should_task_lock: bool
    # When True, a lock collision raises TaskLockException instead of blocking.
    raise_task_lock_exception_on_collision: bool
    # Interval at which the background scheduler extends the lock TTL.
    lock_extend_seconds: int
class TaskLockException(Exception):
    """Raised when the task failed to acquire the lock in the task execution. Only used internally."""
    # Fix: the docstring previously appeared after `pass`, so it was a stray string
    # statement and never became the class docstring. Moving it to be the first
    # statement of the body makes it the actual __doc__; behavior is otherwise unchanged.
class RedisClient:
    """Process-wide cache of ``redis.Redis`` clients, memoized per constructor-argument set.

    ``__new__`` returns the same instance for identical (args, kwargs) pairs, so the
    underlying connection object is shared across repeated constructions.
    NOTE(review): positional and keyword spellings produce *different* cache keys
    (``RedisClient('h', 1)`` vs ``RedisClient(host='h', port=1)``) — confirm this is intended.
    """

    # Class-level registry shaped as {cls: {(args, sorted kwargs): instance}}.
    _instances: dict[Any, Any] = {}

    def __new__(cls, *args, **kwargs):
        key = (args, tuple(sorted(kwargs.items())))
        if cls not in cls._instances:
            cls._instances[cls] = {}
        if key not in cls._instances[cls]:
            cls._instances[cls][key] = super().__new__(cls)
        return cls._instances[cls][key]

    def __init__(self, host: str | None, port: int | None) -> None:
        # __init__ runs on every construction, including cached instances; the hasattr
        # guard ensures the redis client is created only once per instance.
        if not hasattr(self, '_redis_client'):
            host = host or 'localhost'
            port = port or 6379
            self._redis_client = redis.Redis(host=host, port=port)

    def get_redis_client(self):
        # Returns the shared redis.Redis connection for this (host, port).
        return self._redis_client
def _extend_lock(task_lock: redis.lock.Lock, redis_timeout: int) -> None:
    """Reset the TTL of ``task_lock`` to ``redis_timeout`` seconds (invoked periodically by the lock scheduler)."""
    task_lock.extend(additional_time=redis_timeout, replace_ttl=True)
def set_task_lock(task_lock_params: TaskLockParams) -> redis.lock.Lock:
    """Acquire and return a redis lock named ``task_lock_params.redis_key``.

    When ``raise_task_lock_exception_on_collision`` is True the acquisition is
    non-blocking and a collision raises :class:`TaskLockException`; otherwise the
    call blocks until the lock becomes available.

    Raises:
        TaskLockException: if the lock is already held and blocking is disabled.
    """
    redis_client = RedisClient(host=task_lock_params.redis_host, port=task_lock_params.redis_port).get_redis_client()
    blocking = not task_lock_params.raise_task_lock_exception_on_collision
    # thread_local=False — presumably so the background scheduler thread can extend/release the token; confirm.
    task_lock = redis.lock.Lock(redis=redis_client, name=task_lock_params.redis_key, timeout=task_lock_params.redis_timeout, thread_local=False)
    if not task_lock.acquire(blocking=blocking):
        raise TaskLockException('Lock already taken by other task.')
    return task_lock
def set_lock_scheduler(task_lock: redis.lock.Lock, task_lock_params: TaskLockParams) -> BackgroundScheduler:
    """Start a background scheduler that periodically extends ``task_lock``'s TTL.

    This keeps the lock alive while a long-running dump/load/run is in progress.
    The caller is responsible for calling ``scheduler.shutdown()`` once the lock
    has been released.
    """
    scheduler = BackgroundScheduler()
    extend_lock = functools.partial(_extend_lock, task_lock=task_lock, redis_timeout=task_lock_params.redis_timeout or 0)
    scheduler.add_job(
        extend_lock,
        'interval',
        seconds=task_lock_params.lock_extend_seconds,
        max_instances=999999999,  # effectively unlimited concurrent runs of the extend job
        misfire_grace_time=task_lock_params.redis_timeout,
        coalesce=False,
    )
    scheduler.start()
    return scheduler
def make_task_lock_key(file_path: str, unique_id: str | None) -> str:
    """Derive the redis lock key name from a cache file path and a task unique id."""
    file_name = os.path.basename(file_path)
    stem, _extension = os.path.splitext(file_name)
    # None is rendered as the literal string 'None', matching f-string interpolation.
    return '_'.join((stem, str(unique_id)))
def make_task_lock_params(
    file_path: str,
    unique_id: str | None,
    redis_host: str | None = None,
    redis_port: int | None = None,
    redis_timeout: int | None = None,
    raise_task_lock_exception_on_collision: bool = False,
    lock_extend_seconds: int = 10,
) -> TaskLockParams:
    """Build a :class:`TaskLockParams` for locking the cache file at ``file_path``.

    Args:
        file_path: path of the cache file the lock protects; its basename becomes part of the redis key.
        unique_id: unique id of the task, appended to the redis key.
        redis_host: redis server host; locking is enabled only when both host and port are set.
        redis_port: redis server port.
        redis_timeout: lock TTL in seconds, or None for no expiry.
        raise_task_lock_exception_on_collision: when True, a held lock raises instead of blocking.
        lock_extend_seconds: interval at which the lock TTL is extended.

    Raises:
        ValueError: if ``redis_timeout`` is not greater than ``lock_extend_seconds``
            (the lock could otherwise expire between two extensions).
    """
    # Validate first. Raising ValueError instead of using `assert` keeps the check
    # active even when Python runs with optimizations (-O strips asserts).
    if redis_timeout is not None and redis_timeout <= lock_extend_seconds:
        raise ValueError(f'`redis_timeout` must be set greater than lock_extend_seconds:{lock_extend_seconds}, not {redis_timeout}.')
    redis_key = make_task_lock_key(file_path, unique_id)
    should_task_lock = redis_host is not None and redis_port is not None
    return TaskLockParams(
        redis_host=redis_host,
        redis_port=redis_port,
        redis_key=redis_key,
        should_task_lock=should_task_lock,
        redis_timeout=redis_timeout,
        raise_task_lock_exception_on_collision=raise_task_lock_exception_on_collision,
        lock_extend_seconds=lock_extend_seconds,
    )
def make_task_lock_params_for_run(task_self: Any, lock_extend_seconds: int = 10) -> TaskLockParams:
    """Build :class:`TaskLockParams` guarding a task's ``run()`` (rather than a cache file).

    The redis key is derived from the task's module path, class name and unique id
    with a ``-run`` suffix; collisions always raise instead of blocking.
    """
    module_as_path = task_self.__module__.replace('.', '/')
    task_path_name = os.path.join(module_as_path, f'{type(task_self).__name__}')
    run_unique_id = f'{task_self.make_unique_id()}-run'
    task_lock_key = make_task_lock_key(file_path=task_path_name, unique_id=run_unique_id)
    lock_enabled = task_self.redis_host is not None and task_self.redis_port is not None
    return TaskLockParams(
        redis_host=task_self.redis_host,
        redis_port=task_self.redis_port,
        redis_key=task_lock_key,
        should_task_lock=lock_enabled,
        redis_timeout=task_self.redis_timeout,
        raise_task_lock_exception_on_collision=True,
        lock_extend_seconds=lock_extend_seconds,
    )
================================================
FILE: gokart/conflict_prevention_lock/task_lock_wrappers.py
================================================
from __future__ import annotations
import functools
from collections.abc import Callable
from logging import getLogger
from typing import ParamSpec, TypeVar
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, set_lock_scheduler, set_task_lock
logger = getLogger(__name__)
P = ParamSpec('P')
R = TypeVar('R')
def wrap_dump_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams, exist_check: Callable[..., bool]) -> Callable[P, R | None]:
    """Redis lock wrapper function for TargetOnKart.dump().
    When TargetOnKart.dump() is called, dump() will be wrapped with redis lock and cache existance check.
    https://github.com/m3dev/gokart/issues/265

    Returns ``func`` unchanged when locking is disabled. Otherwise the wrapper
    acquires the lock, skips the dump (returning None) when ``exist_check()``
    reports the cache already exists, and always releases the lock afterwards.
    """
    if not task_lock_params.should_task_lock:
        return func

    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | None:
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        # Background job keeps extending the lock TTL while the dump runs.
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)
        try:
            logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} locked.')
            # Another process may have produced the cache while we waited for the lock.
            if not exist_check():
                return func(*args, **kwargs)
            return None
        finally:
            logger.debug(f'Task DUMP lock of {task_lock_params.redis_key} released.')
            task_lock.release()
            scheduler.shutdown()

    return wrapper
def wrap_load_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]:
    """Redis lock wrapper function for TargetOnKart.load().

    When TargetOnKart.load() is called, redis lock will be locked and released before load().
    https://github.com/m3dev/gokart/issues/265

    Returns ``func`` unchanged when locking is disabled.
    """
    if not task_lock_params.should_task_lock:
        return func

    # functools.wraps keeps func's metadata, consistent with wrap_run_with_lock.
    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # NOTE(review): the lock is acquired and immediately released before
        # load() runs — this appears to only wait out a concurrent writer
        # rather than hold the lock during the load; confirm intended.
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)
        logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} locked.')
        task_lock.release()
        logger.debug(f'Task LOAD lock of {task_lock_params.redis_key} released.')
        scheduler.shutdown()
        result = func(*args, **kwargs)
        return result

    return wrapper
def wrap_remove_with_lock(func: Callable[P, R], task_lock_params: TaskLockParams) -> Callable[P, R]:
    """Redis lock wrapper function for TargetOnKart.remove().

    When TargetOnKart.remove() is called, remove() will be simply wrapped with redis lock.
    https://github.com/m3dev/gokart/issues/265

    Returns ``func`` unchanged when locking is disabled; otherwise a wrapper
    that releases the lock and stops the extend-scheduler on both the
    success and the exception path.
    """
    if not task_lock_params.should_task_lock:
        return func

    # functools.wraps keeps func's metadata, consistent with wrap_run_with_lock.
    @functools.wraps(func)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        task_lock = set_task_lock(task_lock_params=task_lock_params)
        scheduler = set_lock_scheduler(task_lock=task_lock, task_lock_params=task_lock_params)
        try:
            logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} locked.')
            result = func(*args, **kwargs)
            task_lock.release()
            logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released.')
            scheduler.shutdown()
            return result
        except BaseException as e:
            logger.debug(f'Task REMOVE lock of {task_lock_params.redis_key} released with BaseException.')
            task_lock.release()
            scheduler.shutdown()
            raise e

    return wrapper
def wrap_run_with_lock(run_func: Callable[[], R], task_lock_params: TaskLockParams) -> Callable[[], R]:
    """Wrap a zero-argument ``run_func`` so it executes under the task's redis lock.

    The lock and its extend-scheduler are released/shut down on both the
    success path and the exception path; exceptions are re-raised.
    """

    @functools.wraps(run_func)
    def wrapped():
        lock = set_task_lock(task_lock_params=task_lock_params)
        lock_scheduler = set_lock_scheduler(task_lock=lock, task_lock_params=task_lock_params)
        try:
            logger.debug(f'Task RUN lock of {task_lock_params.redis_key} locked.')
            value = run_func()
            lock.release()
            logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released.')
            lock_scheduler.shutdown()
            return value
        except BaseException as e:
            logger.debug(f'Task RUN lock of {task_lock_params.redis_key} released with BaseException.')
            lock.release()
            lock_scheduler.shutdown()
            raise e

    return wrapped
================================================
FILE: gokart/errors/__init__.py
================================================
from gokart.build import GokartBuildError, HasLockedTaskException
from gokart.pandas_type_config import PandasTypeError
from gokart.task import EmptyDumpError
# Canonical error types raised by gokart, re-exported here as gokart.errors.
__all__ = [
    'GokartBuildError',
    'HasLockedTaskException',
    'PandasTypeError',
    'EmptyDumpError',
]
================================================
FILE: gokart/file_processor/__init__.py
================================================
"""File processor module with support for multiple DataFrame backends."""
from __future__ import annotations
import os
from typing import Any, Literal
# Export common processors and types from base
from gokart.file_processor.base import (
BinaryFileProcessor,
DataFrameType,
FileProcessor,
GzipFileProcessor,
NpzFileProcessor,
PickleFileProcessor,
TextFileProcessor,
XmlFileProcessor,
)
# Import backend-specific implementations
from gokart.file_processor.pandas import (
CsvFileProcessorPandas,
FeatherFileProcessorPandas,
JsonFileProcessorPandas,
ParquetFileProcessorPandas,
)
from gokart.file_processor.polars import (
CsvFileProcessorPolars,
FeatherFileProcessorPolars,
JsonFileProcessorPolars,
ParquetFileProcessorPolars,
)
class CsvFileProcessor(FileProcessor):
    """CSV file processor with automatic backend selection based on dataframe_type."""

    def __init__(self, sep: str = ',', encoding: str = 'utf-8', dataframe_type: DataFrameType = 'pandas') -> None:
        """
        CSV file processor with support for both pandas and polars DataFrames.

        Args:
            sep: CSV delimiter (default: ',')
            encoding: File encoding (default: 'utf-8')
            dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas')
        """
        self._sep = sep
        self._encoding = encoding
        self._dataframe_type = dataframe_type  # Store for tests
        if dataframe_type in ('polars', 'polars-lazy'):
            self._impl: FileProcessor = CsvFileProcessorPolars(sep=sep, encoding=encoding, lazy=dataframe_type == 'polars-lazy')
        else:
            self._impl = CsvFileProcessorPandas(sep=sep, encoding=encoding)

    def format(self):
        # All behavior is delegated to the backend selected in __init__.
        return self._impl.format()

    def load(self, file):
        return self._impl.load(file)

    def dump(self, obj, file):
        return self._impl.dump(obj, file)
class JsonFileProcessor(FileProcessor):
    """JSON file processor with automatic backend selection based on dataframe_type."""

    def __init__(self, orient: Literal['split', 'records', 'index', 'table', 'columns', 'values'] | None = None, dataframe_type: DataFrameType = 'pandas'):
        """
        JSON file processor with support for both pandas and polars DataFrames.

        Args:
            orient: JSON orientation. 'records' for newline-delimited JSON.
            dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas')
        """
        self._orient = orient
        self._dataframe_type = dataframe_type  # Store for tests
        if dataframe_type in ('polars', 'polars-lazy'):
            self._impl: FileProcessor = JsonFileProcessorPolars(orient=orient, lazy=dataframe_type == 'polars-lazy')
        else:
            self._impl = JsonFileProcessorPandas(orient=orient)

    def format(self):
        # All behavior is delegated to the backend selected in __init__.
        return self._impl.format()

    def load(self, file):
        return self._impl.load(file)

    def dump(self, obj, file):
        return self._impl.dump(obj, file)
class ParquetFileProcessor(FileProcessor):
    """Parquet file processor with automatic backend selection based on dataframe_type."""

    def __init__(self, engine: Any = 'pyarrow', compression: Any = None, dataframe_type: DataFrameType = 'pandas') -> None:
        """
        Parquet file processor with support for both pandas and polars DataFrames.

        Args:
            engine: Parquet engine (pandas-specific, ignored for polars).
            compression: Compression type.
            dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas')
        """
        self._engine = engine
        self._compression = compression
        self._dataframe_type = dataframe_type  # Store for tests
        if dataframe_type in ('polars', 'polars-lazy'):
            self._impl: FileProcessor = ParquetFileProcessorPolars(engine=engine, compression=compression, lazy=dataframe_type == 'polars-lazy')
        else:
            self._impl = ParquetFileProcessorPandas(engine=engine, compression=compression)

    def format(self):
        # All behavior is delegated to the backend selected in __init__.
        return self._impl.format()

    def load(self, file):
        return self._impl.load(file)

    def dump(self, obj, file):
        return self._impl.dump(obj, file)
class FeatherFileProcessor(FileProcessor):
    """Feather file processor with automatic backend selection based on dataframe_type."""

    def __init__(self, store_index_in_feather: bool, dataframe_type: DataFrameType = 'pandas'):
        """
        Feather file processor with support for both pandas and polars DataFrames.

        Args:
            store_index_in_feather: Whether to store pandas index (pandas-only feature).
            dataframe_type: DataFrame library to use for load() - 'pandas', 'polars', or 'polars-lazy' (default: 'pandas')
        """
        self._store_index_in_feather = store_index_in_feather
        self._dataframe_type = dataframe_type  # Store for tests
        if dataframe_type in ('polars', 'polars-lazy'):
            self._impl: FileProcessor = FeatherFileProcessorPolars(store_index_in_feather=store_index_in_feather, lazy=dataframe_type == 'polars-lazy')
        else:
            self._impl = FeatherFileProcessorPandas(store_index_in_feather=store_index_in_feather)

    def format(self):
        # All behavior is delegated to the backend selected in __init__.
        return self._impl.format()

    def load(self, file):
        return self._impl.load(file)

    def dump(self, obj, file):
        return self._impl.dump(obj, file)
def make_file_processor(file_path: str, store_index_in_feather: bool = True, *, dataframe_type: DataFrameType = 'pandas') -> FileProcessor:
    """Create a file processor based on file extension with default parameters.

    Args:
        file_path: Path whose extension selects the processor.
        store_index_in_feather: Passed through to FeatherFileProcessor.
        dataframe_type: DataFrame backend for DataFrame-based processors.

    Raises:
        AssertionError: If the extension is not supported.
    """
    # Map extensions to zero-argument factories so only the processor that is
    # actually requested gets instantiated. The previous eager dict built every
    # processor on each call, which was wasteful and raised ImportError for
    # polars-backed dataframe_type even when a non-DataFrame extension
    # (e.g. '.txt') was requested.
    extension2factory = {
        '.txt': TextFileProcessor,
        '.ini': TextFileProcessor,
        '.csv': lambda: CsvFileProcessor(sep=',', dataframe_type=dataframe_type),
        '.tsv': lambda: CsvFileProcessor(sep='\t', dataframe_type=dataframe_type),
        '.pkl': PickleFileProcessor,
        '.gz': GzipFileProcessor,
        '.json': lambda: JsonFileProcessor(dataframe_type=dataframe_type),
        '.ndjson': lambda: JsonFileProcessor(dataframe_type=dataframe_type, orient='records'),
        '.xml': XmlFileProcessor,
        '.npz': NpzFileProcessor,
        '.parquet': lambda: ParquetFileProcessor(compression='gzip', dataframe_type=dataframe_type),
        '.feather': lambda: FeatherFileProcessor(store_index_in_feather=store_index_in_feather, dataframe_type=dataframe_type),
        '.png': BinaryFileProcessor,
        '.jpg': BinaryFileProcessor,
    }
    extension = os.path.splitext(file_path)[1]
    assert extension in extension2factory, f'{extension} is not supported. The supported extensions are {list(extension2factory.keys())}.'
    return extension2factory[extension]()
# Public API of gokart.file_processor.
__all__ = [
    # Base classes and types
    'FileProcessor',
    'DataFrameType',
    # Common processors
    'BinaryFileProcessor',
    'PickleFileProcessor',
    'TextFileProcessor',
    'GzipFileProcessor',
    'XmlFileProcessor',
    'NpzFileProcessor',
    # DataFrame processors (with factory pattern)
    'CsvFileProcessor',
    'JsonFileProcessor',
    'ParquetFileProcessor',
    'FeatherFileProcessor',
    # Utility functions
    'make_file_processor',
]
================================================
FILE: gokart/file_processor/base.py
================================================
from __future__ import annotations
import xml.etree.ElementTree as ET
from abc import abstractmethod
from io import BytesIO
from logging import getLogger
from typing import Any, Literal, cast
import dill
import luigi
import luigi.format
import numpy as np
from gokart.utils import load_dill_with_pandas_backward_compatibility
logger = getLogger(__name__)
# Type alias for DataFrame library return type
DataFrameType = Literal['pandas', 'polars', 'polars-lazy']
class FileProcessor:
    """Interface for serializing an object to, and deserializing it from, a file.

    NOTE(review): methods are marked @abstractmethod but the class does not
    inherit abc.ABC, so direct instantiation is not blocked at runtime —
    confirm whether that is intentional.
    """

    # Returns the luigi file format to open the target with (e.g. luigi.format.Nop),
    # or None for plain text.
    @abstractmethod
    def format(self) -> Any: ...

    # Deserialize and return the object stored in `file`.
    @abstractmethod
    def load(self, file: Any) -> Any: ...

    # Serialize `obj` into `file`.
    @abstractmethod
    def dump(self, obj: Any, file: Any) -> None: ...
class BinaryFileProcessor(FileProcessor):
    """
    Pass bytes to this processor
    ```
    figure_binary = io.BytesIO()
    plt.savefig(figure_binary)
    figure_binary.seek(0)
    BinaryFileProcessor().dump(figure_binary.read())
    ```
    """

    def format(self):
        # Nop: the target is opened in binary mode with no conversion.
        return luigi.format.Nop

    def load(self, file):
        data = file.read()
        return data

    def dump(self, obj, file):
        file.write(obj)
class _ChunkedLargeFileReader:
def __init__(self, file: Any) -> None:
self._file = file
def __getattr__(self, item):
return getattr(self._file, item)
def read(self, n: int) -> bytes:
if n >= (1 << 31):
logger.info(f'reading a large file with total_bytes={n}.')
buffer = bytearray(n)
idx = 0
while idx < n:
batch_size = min(n - idx, (1 << 31) - 1)
logger.info(f'reading bytes [{idx}, {idx + batch_size})...')
buffer[idx : idx + batch_size] = self._file.read(batch_size)
idx += batch_size
logger.info('done.')
return bytes(buffer)
return cast(bytes, self._file.read(n))
def readline(self) -> bytes:
return cast(bytes, self._file.readline())
def seek(self, offset: int) -> None:
self._file.seek(offset)
def seekable(self) -> bool:
return cast(bool, self._file.seekable())
class PickleFileProcessor(FileProcessor):
    """Pickle (dill) serialization with chunked writes for very large payloads."""

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        if file.seekable():
            return load_dill_with_pandas_backward_compatibility(_ChunkedLargeFileReader(file))
        # load_dill_with_pandas_backward_compatibility() requires file with seek() and readlines() implemented.
        # Therefore, we need to wrap with BytesIO which makes file seekable and readlinesable.
        # For example, ReadableS3File is not a seekable file.
        return load_dill_with_pandas_backward_compatibility(BytesIO(file.read()))

    def dump(self, obj, file):
        self._write(dill.dumps(obj, protocol=4), file)

    @staticmethod
    def _write(buffer, file):
        # Write in chunks below 2GiB to avoid single-write size limits.
        total = len(buffer)
        offset = 0
        while offset < total:
            logger.info(f'writing a file with total_bytes={total}...')
            step = min(total - offset, (1 << 31) - 1)
            logger.info(f'writing bytes [{offset}, {offset + step})')
            file.write(buffer[offset : offset + step])
            offset += step
        logger.info('done')
class TextFileProcessor(FileProcessor):
    """Plain-text processor: loads rstripped lines; dumps str(obj), one line per list item."""

    def format(self):
        return None

    def load(self, file):
        return [line.rstrip() for line in file.readlines()]

    def dump(self, obj, file):
        if not isinstance(obj, list):
            file.write(str(obj))
            return
        for item in obj:
            file.write(f'{item}\n')
class GzipFileProcessor(FileProcessor):
    """Gzip-compressed text processor: like TextFileProcessor but encodes/decodes bytes."""

    def format(self):
        return luigi.format.Gzip

    def load(self, file):
        return [raw.rstrip().decode() for raw in file.readlines()]

    def dump(self, obj, file):
        if not isinstance(obj, list):
            file.write(str(obj).encode())
            return
        for item in obj:
            file.write(f'{item}\n'.encode())
class XmlFileProcessor(FileProcessor):
    """XML processor backed by xml.etree.ElementTree."""

    def format(self):
        return None

    def load(self, file):
        try:
            tree = ET.parse(file)
        except ET.ParseError:
            # Malformed or empty XML degrades to an empty tree instead of raising.
            return ET.ElementTree()
        return tree

    def dump(self, obj, file):
        assert isinstance(obj, ET.ElementTree), f'requires ET.ElementTree, but {type(obj)} is passed.'
        obj.write(file)
class NpzFileProcessor(FileProcessor):
    """Compressed numpy array processor; the array is stored under the key 'data'."""

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        archive = np.load(file)
        return archive['data']

    def dump(self, obj, file):
        assert isinstance(obj, np.ndarray), f'requires np.ndarray, but {type(obj)} is passed.'
        np.savez_compressed(file, data=obj)
================================================
FILE: gokart/file_processor/pandas.py
================================================
"""Pandas-specific file processor implementations."""
from __future__ import annotations
from io import BytesIO
from typing import Literal
import luigi
import luigi.format
import pandas as pd
from luigi.format import TextFormat
from gokart.file_processor.base import FileProcessor
from gokart.object_storage import ObjectStorage
class CsvFileProcessorPandas(FileProcessor):
    """CSV file processor for pandas DataFrames."""

    def __init__(self, sep: str = ',', encoding: str = 'utf-8') -> None:
        super().__init__()
        self._sep = sep
        self._encoding = encoding

    def format(self):
        return TextFormat(encoding=self._encoding)

    def load(self, file):
        try:
            frame = pd.read_csv(file, sep=self._sep, encoding=self._encoding)
        except pd.errors.EmptyDataError:
            # An empty file is treated as an empty DataFrame rather than an error.
            return pd.DataFrame()
        return frame

    def dump(self, obj, file):
        if not isinstance(obj, pd.DataFrame | pd.Series):
            raise TypeError(f'requires pd.DataFrame or pd.Series, but {type(obj)} is passed.')
        obj.to_csv(file, mode='wt', index=False, sep=self._sep, header=True, encoding=self._encoding)
_JsonOrient = Literal['split', 'records', 'index', 'table', 'columns', 'values']
class JsonFileProcessorPandas(FileProcessor):
    """JSON file processor for pandas DataFrames."""

    def __init__(self, orient: _JsonOrient | None = None):
        self._orient: _JsonOrient | None = orient

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # 'records' means newline-delimited JSON, hence lines=True.
        use_lines = self._orient == 'records'
        try:
            return pd.read_json(file, orient=self._orient, lines=use_lines)
        except pd.errors.EmptyDataError:
            return pd.DataFrame()

    def dump(self, obj, file):
        if isinstance(obj, dict):
            obj = pd.DataFrame.from_dict(obj)
        if not isinstance(obj, pd.DataFrame | pd.Series):
            raise TypeError(f'requires pd.DataFrame or pd.Series or dict, but {type(obj)} is passed.')
        obj.to_json(file, orient=self._orient, lines=self._orient == 'records')
class ParquetFileProcessorPandas(FileProcessor):
    """Parquet file processor for pandas DataFrames."""

    def __init__(self, engine: Literal['auto', 'pyarrow', 'fastparquet'] = 'pyarrow', compression: str | None = None) -> None:
        super().__init__()
        self._engine: Literal['auto', 'pyarrow', 'fastparquet'] = engine
        self._compression = compression

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_parquet accepts a file-like object, but file
        # (luigi.contrib.s3.ReadableS3File) would need a 'tell' method,
        # which pandas requires for chunked reads.
        if not ObjectStorage.is_buffered_reader(file):
            return pd.read_parquet(BytesIO(file.read()))
        return pd.read_parquet(file.name)

    def dump(self, obj, file):
        if not isinstance(obj, pd.DataFrame):
            raise TypeError(f'requires pd.DataFrame, but {type(obj)} is passed.')
        # MEMO: to_parquet only supports a filepath as string (not a file handle)
        obj.to_parquet(file.name, index=False, engine=self._engine, compression=self._compression)
class FeatherFileProcessorPandas(FileProcessor):
    """Feather file processor for pandas DataFrames.

    Feather itself cannot store a pandas index, so when
    ``store_index_in_feather`` is True the index is round-tripped through a
    column whose name carries the ``INDEX_COLUMN_PREFIX`` marker.
    """

    def __init__(self, store_index_in_feather: bool):
        super().__init__()
        self._store_index_in_feather = store_index_in_feather
        # Marker prefix for the column that carries the serialized index.
        self.INDEX_COLUMN_PREFIX = '__feather_gokart_index__'

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # FIXME(mamo3gr): enable streaming (chunked) read with S3.
        # pandas.read_feather accepts file-like object
        # but file (luigi.contrib.s3.ReadableS3File) should have 'tell' method,
        # which is needed for pandas to read a file in chunks.
        if ObjectStorage.is_buffered_reader(file):
            loaded_df = pd.read_feather(file.name)
        else:
            loaded_df = pd.read_feather(BytesIO(file.read()))
        if self._store_index_in_feather:
            if any(col.startswith(self.INDEX_COLUMN_PREFIX) for col in loaded_df.columns):
                # Restore the index from the prefixed column (searched from the
                # right; the rightmost match wins) and drop it from the columns.
                index_columns = [col_name for col_name in loaded_df.columns[::-1] if col_name[: len(self.INDEX_COLUMN_PREFIX)] == self.INDEX_COLUMN_PREFIX]
                index_column = index_columns[0]
                index_name = index_column[len(self.INDEX_COLUMN_PREFIX) :]
                # The literal string 'None' encodes an unnamed index (see dump()).
                if index_name == 'None':
                    index_name = None
                loaded_df.index = pd.Index(loaded_df[index_column].values, name=index_name)
                loaded_df = loaded_df.drop(columns=[index_column])
        return loaded_df

    def dump(self, obj, file):
        if not isinstance(obj, pd.DataFrame):
            raise TypeError(f'requires pd.DataFrame, but {type(obj)} is passed.')
        dump_obj = obj.copy()
        if self._store_index_in_feather:
            # An unnamed index serializes its name as the string 'None';
            # load() maps that sentinel back to None.
            index_column_name = f'{self.INDEX_COLUMN_PREFIX}{dump_obj.index.name}'
            assert index_column_name not in dump_obj.columns, (
                f'column name {index_column_name} already exists in dump_obj. \nConsider not saving index by setting store_index_in_feather=False.'
            )
            assert dump_obj.index.name != 'None', 'index name is "None", which is not allowed in gokart. Consider setting another index name.'
            dump_obj[index_column_name] = dump_obj.index
            dump_obj = dump_obj.reset_index(drop=True)
        # to_feather supports "binary" file-like object, but file variable is text
        dump_obj.to_feather(file.name)
================================================
FILE: gokart/file_processor/polars.py
================================================
"""Polars-specific file processor implementations."""
from __future__ import annotations
from io import BytesIO
from typing import TYPE_CHECKING, Literal
import luigi
import luigi.format
from luigi.format import TextFormat
from gokart.file_processor.base import FileProcessor
from gokart.object_storage import ObjectStorage
_CsvEncoding = Literal['utf8', 'utf8-lossy']
_ParquetCompression = Literal['lz4', 'uncompressed', 'snappy', 'gzip', 'brotli', 'zstd']
try:
import polars as pl
HAS_POLARS = True
except ImportError:
HAS_POLARS = False
if TYPE_CHECKING:
import polars as pl
class CsvFileProcessorPolars(FileProcessor):
    """CSV file processor for polars DataFrames."""

    def __init__(self, sep: str = ',', encoding: str = 'utf-8', lazy: bool = False) -> None:
        if not HAS_POLARS:
            raise ImportError("polars is required for polars-based dataframe types ('polars' or 'polars-lazy'). Install with: pip install polars")
        super().__init__()
        self._sep = sep
        self._encoding = encoding
        self._lazy = lazy

    def format(self):
        return TextFormat(encoding=self._encoding)

    def load(self, file):
        # scan_csv/read_csv only support 'utf8' and 'utf8-lossy'
        encoding: _CsvEncoding = 'utf8' if self._encoding in ('utf-8', 'utf8') else 'utf8-lossy'
        try:
            if self._lazy:
                # scan_csv requires a file path, not a file object
                return pl.scan_csv(file.name, separator=self._sep, encoding=encoding)
            return pl.read_csv(file, separator=self._sep, encoding=encoding)
        except Exception as e:
            # Degrade empty inputs to an empty frame instead of raising.
            message = str(e).lower()
            if 'empty' in message or 'no data' in message:
                return pl.LazyFrame() if self._lazy else pl.DataFrame()
            raise

    def dump(self, obj, file):
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect()
        if not isinstance(obj, pl.DataFrame):
            raise TypeError(f'requires pl.DataFrame or pl.LazyFrame, but {type(obj)} is passed.')
        obj.write_csv(file, separator=self._sep, include_header=True)
class JsonFileProcessorPolars(FileProcessor):
    """JSON file processor for polars DataFrames."""

    def __init__(self, orient: str | None = None, lazy: bool = False):
        if not HAS_POLARS:
            raise ImportError("polars is required for polars-based dataframe types ('polars' or 'polars-lazy'). Install with: pip install polars")
        self._orient = orient
        self._lazy = lazy

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        try:
            if self._orient != 'records':
                # polars has no scan_json, so read eagerly and convert when lazy.
                frame = pl.read_json(file)
                return frame.lazy() if self._lazy else frame
            # 'records' means newline-delimited JSON.
            return pl.scan_ndjson(file) if self._lazy else pl.read_ndjson(file)
        except Exception as e:
            # Degrade empty inputs to an empty frame instead of raising.
            message = str(e).lower()
            if 'empty' in message or 'no data' in message:
                return pl.LazyFrame() if self._lazy else pl.DataFrame()
            raise

    def dump(self, obj, file):
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect()
        if not isinstance(obj, pl.DataFrame):
            raise TypeError(f'requires pl.DataFrame or pl.LazyFrame, but {type(obj)} is passed.')
        if self._orient == 'records':
            obj.write_ndjson(file)
        else:
            obj.write_json(file)
class ParquetFileProcessorPolars(FileProcessor):
    """Parquet file processor for polars DataFrames."""

    def __init__(self, engine: str = 'pyarrow', compression: _ParquetCompression | None = None, lazy: bool = False) -> None:
        if not HAS_POLARS:
            raise ImportError("polars is required for polars-based dataframe types ('polars' or 'polars-lazy'). Install with: pip install polars")
        super().__init__()
        self._engine = engine  # Ignored for polars
        self._compression: _ParquetCompression | None = compression
        self._lazy = lazy

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # Buffered local files expose a real path, so prefer the path-based APIs.
        if ObjectStorage.is_buffered_reader(file):
            return pl.scan_parquet(file.name) if self._lazy else pl.read_parquet(file.name)
        buffered = BytesIO(file.read())
        if self._lazy:
            # scan_parquet doesn't work with BytesIO, so read and convert
            return pl.read_parquet(buffered).lazy()
        return pl.read_parquet(buffered)

    def dump(self, obj, file):
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect()
        if not isinstance(obj, pl.DataFrame):
            raise TypeError(f'requires pl.DataFrame or pl.LazyFrame, but {type(obj)} is passed.')
        # polars write_parquet requires a file path; default to 'zstd' when compression is None
        obj.write_parquet(file.name, compression=self._compression or 'zstd')
class FeatherFileProcessorPolars(FileProcessor):
    """Feather file processor for polars DataFrames (feather == the IPC format in polars)."""

    def __init__(self, store_index_in_feather: bool, lazy: bool = False):
        if not HAS_POLARS:
            raise ImportError("polars is required for polars-based dataframe types ('polars' or 'polars-lazy'). Install with: pip install polars")
        super().__init__()
        self._store_index_in_feather = store_index_in_feather  # Ignored for polars
        self._lazy = lazy

    def format(self):
        return luigi.format.Nop

    def load(self, file):
        # polars uses read_ipc/scan_ipc for the feather format.
        if ObjectStorage.is_buffered_reader(file):
            return pl.scan_ipc(file.name) if self._lazy else pl.read_ipc(file.name)
        buffered = BytesIO(file.read())
        if self._lazy:
            # scan_ipc doesn't work with BytesIO, so read and convert
            return pl.read_ipc(buffered).lazy()
        return pl.read_ipc(buffered)

    def dump(self, obj, file):
        if isinstance(obj, pl.LazyFrame):
            obj = obj.collect()
        if not isinstance(obj, pl.DataFrame):
            raise TypeError(f'requires pl.DataFrame or pl.LazyFrame, but {type(obj)} is passed.')
        # polars uses write_ipc for feather format
        # Note: store_index_in_feather is ignored for polars as it's pandas-specific
        obj.write_ipc(file.name)
================================================
FILE: gokart/file_processor.py
================================================
================================================
FILE: gokart/gcs_config.py
================================================
from __future__ import annotations
import json
import os
from typing import cast
import luigi
import luigi.contrib.gcs
from google.oauth2.service_account import Credentials
class GCSConfig(luigi.Config):
    """Luigi config section providing an authenticated GCS client.

    Credentials are taken from the environment variable named by
    ``gcs_credential_name``; its value may be either a path to a service
    account JSON file or the JSON content itself.
    """

    gcs_credential_name: luigi.StrParameter = luigi.StrParameter(default='GCS_CREDENTIAL', description='GCS credential environment variable.')
    # Cache slot for the GCS client so repeated calls reuse one client.
    # NOTE(review): get_gcs_client assigns via `self`, which creates an
    # instance attribute rather than updating this class attribute — caching
    # is per config instance; confirm this is intended.
    _client = None

    def get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
        if self._client is None:  # use cache as like singleton object
            self._client = self._get_gcs_client()
        return self._client

    def _get_gcs_client(self) -> luigi.contrib.gcs.GCSClient:
        return luigi.contrib.gcs.GCSClient(oauth_credentials=self._load_oauth_credentials())

    def _load_oauth_credentials(self) -> Credentials | None:
        """Return service-account Credentials from the env var, or None when unset/empty."""
        json_str = os.environ.get(self.gcs_credential_name)
        if not json_str:
            return None
        # The env var may hold a file path or the raw service-account JSON.
        if os.path.isfile(json_str):
            return cast(Credentials, Credentials.from_service_account_file(json_str))
        return cast(Credentials, Credentials.from_service_account_info(json.loads(json_str)))
================================================
FILE: gokart/gcs_obj_metadata_client.py
================================================
from __future__ import annotations
import copy
import functools
import json
import re
from collections.abc import Iterable
from logging import getLogger
from typing import Any, Final
from urllib.parse import urlsplit
from googleapiclient.model import makepatch
from gokart.gcs_config import GCSConfig
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems
logger = getLogger(__name__)
class GCSObjectMetadataClient:
    """
    This class is Utility-Class, so should not be initialized.
    This class used for adding metadata as labels.
    """

    # Maximum metadata size for GCS objects (8 KiB)
    MAX_GCS_METADATA_SIZE: Final[int] = 8 * 1024

    @staticmethod
    def _is_log_related_path(path: str) -> bool:
        # Log artifacts (processing_time, task_info, ...) are skipped by
        # add_task_state_labels: they do not get searchable labels.
        return re.match(r'^gs://.+?/log/(processing_time/|task_info/|task_log/|module_versions/|random_seed/|task_params/).+', path) is not None

    # This is the copied method of luigi.gcs._path_to_bucket_and_key(path).
    @staticmethod
    def _path_to_bucket_and_key(path: str) -> tuple[str, str]:
        """Split a 'gs://bucket/key' URL into ('bucket', 'key')."""
        (scheme, netloc, path, _, _) = urlsplit(path)
        assert scheme == 'gs'
        path_without_initial_slash = path[1:]
        return netloc, path_without_initial_slash

    @staticmethod
    def add_task_state_labels(
        path: str,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        """Attach task parameters / custom labels as metadata on the GCS object at ``path``.

        Best-effort: failures to fetch or patch the object are logged, not raised.
        The patch request is issued only when the merged labels actually change
        the object's existing metadata.
        """
        if GCSObjectMetadataClient._is_log_related_path(path):
            return
        # In gokart/object_storage.get_time_stamp, could find same call.
        # _path_to_bucket_and_key is a private method, so, this might not be acceptable.
        bucket, obj = GCSObjectMetadataClient._path_to_bucket_and_key(path)
        _response = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute()
        if _response is None:
            logger.error(f'failed to get object from GCS bucket {bucket} and object {obj}.')
            return
        response: dict[str, Any] = dict(_response)
        original_metadata: dict[Any, Any] = {}
        if 'metadata' in response.keys():
            _metadata = response.get('metadata')
            if _metadata is not None:
                original_metadata = dict(_metadata)
        patched_metadata = GCSObjectMetadataClient._get_patched_obj_metadata(
            copy.deepcopy(original_metadata),
            task_params,
            custom_labels,
            required_task_outputs,
        )
        if original_metadata != patched_metadata:
            # If we use update api, existing object metadata are removed, so should use patch api.
            # See the official document descriptions.
            # [Link] https://cloud.google.com/storage/docs/viewing-editing-metadata?hl=ja#rest-set-object-metadata
            update_response = (
                GCSConfig()
                .get_gcs_client()
                .client.objects()
                .patch(
                    bucket=bucket,
                    object=obj,
                    body=makepatch({'metadata': original_metadata}, {'metadata': patched_metadata}),
                )
                .execute()
            )
            if update_response is None:
                # NOTE(review): message mentions {obj} twice; likely intended to name it only once.
                logger.error(f'failed to patch object {obj} in bucket {bucket} and object {obj}.')

    @staticmethod
    def _normalize_labels(labels: dict[str, Any] | None) -> dict[str, str]:
        # GCS metadata values must be strings; stringify keys and values uniformly.
        return {str(key): str(value) for key, value in labels.items()} if labels else {}

    @staticmethod
    def _get_patched_obj_metadata(
        metadata: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> dict[str, Any] | Any:
        """Merge task params / custom labels into ``metadata`` and trim to the 8 KiB limit."""
        # If metadata from response when getting bucket and object information is not dictionary,
        # something wrong might be happened, so return original metadata, no patched.
        if not isinstance(metadata, dict):
            logger.warning(f'metadata is not a dict: {metadata}, something wrong was happened when getting response when get bucket and object information.')
            return metadata
        # Maximum size of metadata for each object is 8 KiB.
        # [Link]: https://cloud.google.com/storage/quotas#objects
        normalized_task_params_labels = GCSObjectMetadataClient._normalize_labels(task_params)
        normalized_custom_labels = GCSObjectMetadataClient._normalize_labels(custom_labels)
        # There is a possibility that the keys of user-provided labels(custom_labels) may conflict with those generated from task parameters (task_params_labels).
        # However, users who utilize custom_labels are no longer expected to search using the labels generated from task parameters.
        # Instead, users are expected to search using the labels they provided.
        # Therefore, in the event of a key conflict, the value registered by the user-provided labels will take precedence.
        normalized_labels = [normalized_custom_labels, normalized_task_params_labels]
        if required_task_outputs:
            normalized_labels.append({'__required_task_outputs': json.dumps(GCSObjectMetadataClient._get_serialized_string(required_task_outputs))})
        _merged_labels = GCSObjectMetadataClient._merge_custom_labels_and_task_params_labels(normalized_labels)
        return GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(dict(metadata) | _merged_labels)

    @staticmethod
    def _get_serialized_string(required_task_outputs: FlattenableItems[RequiredTaskOutput]) -> FlattenableItems[str]:
        """Recursively serialize required task outputs, preserving dict/list structure."""
        if isinstance(required_task_outputs, RequiredTaskOutput):
            return required_task_outputs.serialize()
        elif isinstance(required_task_outputs, dict):
            return {k: GCSObjectMetadataClient._get_serialized_string(v) for k, v in required_task_outputs.items()}
        elif isinstance(required_task_outputs, Iterable):
            return [GCSObjectMetadataClient._get_serialized_string(ro) for ro in required_task_outputs]
        else:
            raise TypeError(
                f'Unsupported type for required_task_outputs: {type(required_task_outputs)}. '
                'It should be RequiredTaskOutput, dict, or iterable of RequiredTaskOutput.'
            )

    @staticmethod
    def _merge_custom_labels_and_task_params_labels(
        normalized_labels_list: list[dict[str, str]],
    ) -> dict[str, str]:
        """Left-fold the label dicts; earlier dicts in the list win on key conflicts."""

        def __merge_two_dicts_helper(merged: dict[str, str], current_labels: dict[str, str]) -> dict[str, str]:
            next_merged = copy.deepcopy(merged)
            for label_name, label_value in current_labels.items():
                if len(label_value) == 0:
                    logger.warning(f'value of label_name={label_name} is empty. So skip to add as a metadata.')
                    continue
                if label_name in next_merged:
                    logger.warning(f'label_name={label_name} is already seen. So skip to add as metadata.')
                    continue
                next_merged[label_name] = label_value
            return next_merged

        return functools.reduce(__merge_two_dicts_helper, normalized_labels_list, {})

    # Google Cloud Storage(GCS) has a limitation of metadata size, 8 KiB.
    # So, we need to adjust the size of metadata.
    @staticmethod
    def _adjust_gcs_metadata_limit_size(_labels: dict[str, str]) -> dict[str, str]:
        """Drop labels (last-inserted first) until the total size fits in 8 KiB."""

        def _get_label_size(label_name: str, label_value: str) -> int:
            # Size counts UTF-8 bytes of both the key and the value.
            return len(label_name.encode('utf-8')) + len(label_value.encode('utf-8'))

        labels = copy.deepcopy(_labels)
        max_gcs_metadata_size, current_total_metadata_size = (
            GCSObjectMetadataClient.MAX_GCS_METADATA_SIZE,
            sum(_get_label_size(label_name, label_value) for label_name, label_value in labels.items()),
        )
        if current_total_metadata_size <= max_gcs_metadata_size:
            return labels
        # NOTE: remove labels to stay within max metadata size.
        to_remove = []
        for label_name, label_value in reversed(tuple(labels.items())):
            size = _get_label_size(label_name, label_value)
            to_remove.append(label_name)
            current_total_metadata_size -= size
            if current_total_metadata_size <= max_gcs_metadata_size:
                break
        for key in to_remove:
            del labels[key]
        return labels
================================================
FILE: gokart/gcs_zip_client.py
================================================
from __future__ import annotations
import os
import shutil
from typing import cast
from gokart.gcs_config import GCSConfig
from gokart.zip_client import ZipClient, _unzip_file
class GCSZipClient(ZipClient):
    """``ZipClient`` that archives a local working directory to/from Google Cloud Storage."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory
        self._client = GCSConfig().get_gcs_client()

    def exists(self) -> bool:
        """Return whether the archive object already exists on GCS."""
        return cast(bool, self._client.exists(self._file_path))

    def make_archive(self) -> None:
        """Zip the temporary directory and upload the resulting archive to GCS."""
        # The archive format (e.g. 'zip') is taken from the remote path's extension.
        archive_format = os.path.splitext(self._file_path)[1][1:]
        shutil.make_archive(base_name=self._temporary_directory, format=archive_format, root_dir=self._temporary_directory)
        self._client.put(self._temporary_file_path(), self._file_path)

    def unpack_archive(self) -> None:
        """Download the archive from GCS and extract it into the temporary directory."""
        os.makedirs(self._temporary_directory, exist_ok=True)
        downloaded = self._client.download(self._file_path)
        _unzip_file(fp=downloaded, extract_dir=self._temporary_directory)

    def remove(self) -> None:
        """Delete the archive object from GCS."""
        self._client.remove(self._file_path)

    @property
    def path(self) -> str:
        return self._file_path

    def _temporary_file_path(self):
        # Local path of the archive produced by `make_archive`: the temporary
        # directory name (minus a trailing slash) plus the remote extension.
        extension = os.path.splitext(self._file_path)[1]
        local_base = self._temporary_directory
        if local_base.endswith('/'):
            local_base = local_base[:-1]
        return local_base + extension
================================================
FILE: gokart/in_memory/__init__.py
================================================
__all__ = [
'InMemoryCacheRepository',
'InMemoryTarget',
'make_in_memory_target',
]
from .repository import InMemoryCacheRepository
from .target import InMemoryTarget, make_in_memory_target
================================================
FILE: gokart/in_memory/data.py
================================================
from __future__ import annotations
from dataclasses import dataclass
from datetime import datetime
from typing import Any
@dataclass
class InMemoryData:
    """A cached value together with the time it was stored."""

    # The cached object itself.
    value: Any
    # When this entry was created (local time via datetime.now()).
    last_modification_time: datetime

    @classmethod
    def create_data(cls, value: Any) -> InMemoryData:
        """Wrap ``value`` with the current timestamp.

        Fixed: the first parameter of this ``classmethod`` was misnamed ``self``
        and ignored; it is now ``cls`` and used to construct the instance, so
        subclasses calling ``create_data`` get instances of themselves.
        """
        return cls(value=value, last_modification_time=datetime.now())
================================================
FILE: gokart/in_memory/repository.py
================================================
from __future__ import annotations
from collections.abc import Iterator
from datetime import datetime
from typing import Any
from .data import InMemoryData
class InMemoryCacheRepository:
    """Process-wide in-memory key/value cache.

    ``_cache`` is a class attribute, so every instance shares the same storage;
    constructing a repository is cheap and carries no per-instance state.
    """

    _cache: dict[str, InMemoryData] = {}

    def __init__(self):
        pass

    def get_value(self, key: str) -> Any:
        """Return the cached object stored under ``key``."""
        return self._get_data(key).value

    def get_last_modification_time(self, key: str) -> datetime:
        """Return the timestamp recorded when ``key`` was stored."""
        return self._get_data(key).last_modification_time

    def _get_data(self, key: str) -> InMemoryData:
        # Raises KeyError for unknown keys, like a plain dict lookup.
        return self._cache[key]

    def set_value(self, key: str, obj: Any) -> None:
        """Store ``obj`` under ``key``, stamping it with the current time."""
        self._cache[key] = InMemoryData.create_data(obj)

    def has(self, key: str) -> bool:
        return key in self._cache

    def remove(self, key: str) -> None:
        assert self.has(key), f'{key} does not exist.'
        del self._cache[key]

    def empty(self) -> bool:
        return len(self._cache) == 0

    def clear(self) -> None:
        self._cache.clear()

    def get_gen(self) -> Iterator[tuple[str, Any]]:
        """Yield ``(key, value)`` pairs for every cached entry."""
        for key, entry in self._cache.items():
            yield key, entry.value

    @property
    def size(self) -> int:
        return len(self._cache)
================================================
FILE: gokart/in_memory/target.py
================================================
from __future__ import annotations
from datetime import datetime
from typing import Any
from gokart.in_memory.repository import InMemoryCacheRepository
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart, TaskLockParams
from gokart.utils import FlattenableItems
_repository = InMemoryCacheRepository()
class InMemoryTarget(TargetOnKart):
    """``TargetOnKart`` backed by the module-level in-memory cache repository."""

    def __init__(self, data_key: str, task_lock_param: TaskLockParams):
        # Task locking is Redis-based and has no meaning for purely in-process
        # storage, so refuse it up front.
        if task_lock_param.should_task_lock:
            raise ValueError('Redis with `InMemoryTarget` is not currently supported.')
        self._data_key = data_key
        self._task_lock_params = task_lock_param

    def _exists(self) -> bool:
        return _repository.has(self._data_key)

    def _get_task_lock_params(self) -> TaskLockParams:
        return self._task_lock_params

    def _load(self) -> Any:
        return _repository.get_value(self._data_key)

    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        # The metadata arguments are accepted for interface compatibility but
        # ignored: the in-memory repository stores only the object itself.
        return _repository.set_value(self._data_key, obj)

    def _remove(self) -> None:
        _repository.remove(self._data_key)

    def _last_modification_time(self) -> datetime:
        if not _repository.has(self._data_key):
            raise ValueError(f'No object(s) which id is {self._data_key} are stored before.')
        return _repository.get_last_modification_time(self._data_key)

    def _path(self) -> str:
        # TODO: the name `_path` might not be appropriate for an in-memory key.
        return self._data_key
def make_in_memory_target(target_key: str, task_lock_params: TaskLockParams) -> InMemoryTarget:
    """Factory helper that builds an ``InMemoryTarget`` for ``target_key``."""
    return InMemoryTarget(data_key=target_key, task_lock_param=task_lock_params)
================================================
FILE: gokart/info.py
================================================
from __future__ import annotations
from logging import getLogger
from typing import Any
import luigi
from gokart.task import TaskOnKart
from gokart.tree.task_info import make_task_info_as_tree_str
logger = getLogger(__name__)
def make_tree_info(
    task: TaskOnKart[Any],
    indent: str = '',
    last: bool = True,
    details: bool = False,
    abbr: bool = True,
    visited_tasks: set[str] | None = None,
    ignore_task_names: list[str] | None = None,
) -> str:
    """
    Return a string representation of the tasks, their statuses/parameters in a dependency tree format

    This function has moved to `gokart.tree.task_info.make_task_info_as_tree_str`.
    This code is remained for backward compatibility.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - indent: str
        Unused; kept only so the historical signature stays callable.
    - last: bool
        Unused; kept only so the historical signature stays callable.
    - details: bool
        Whether or not to output details.
    - abbr: bool
        Whether or not to simplify tasks information that has already appeared.
    - visited_tasks: set[str] | None
        Unused; kept only so the historical signature stays callable.
    - ignore_task_names: list[str] | None
        List of task names to ignore.

    Returns
    -------
    - tree_info : str
        Formatted task dependency tree.
    """
    # Only task/details/abbr/ignore_task_names are forwarded to the new implementation.
    return make_task_info_as_tree_str(task=task, details=details, abbr=abbr, ignore_task_names=ignore_task_names)
class tree_info(TaskOnKart[Any]):
    """Task whose parameters configure dependency-tree output written to ``output_path``."""

    # 'mode' is validated elsewhere; per its description it must be "simple" or "all".
    mode: luigi.StrParameter = luigi.StrParameter(default='', description='This must be in ["simple", "all"].')
    output_path: luigi.StrParameter = luigi.StrParameter(default='tree.txt', description='Output file path.')

    def output(self):
        # Fixed path (no unique-id suffix) so the generated tree file is easy to locate.
        return self.make_target(self.output_path, use_unique_id=False)
================================================
FILE: gokart/mypy.py
================================================
"""Plugin that provides support for gokart.TaskOnKart.
This Code reuses the code from mypy.plugins.dataclasses
https://github.com/python/mypy/blob/0753e2a82dad35034e000609b6e8daa37238bfaa/mypy/plugins/dataclasses.py
"""
from __future__ import annotations
import re
import sys
import warnings
from collections.abc import Callable, Iterator
from dataclasses import dataclass
from enum import Enum
from typing import Any, Final, Literal
import luigi
from mypy.expandtype import expand_type
from mypy.nodes import (
ARG_NAMED,
ARG_NAMED_OPT,
ArgKind,
Argument,
AssignmentStmt,
Block,
CallExpr,
ClassDef,
EllipsisExpr,
Expression,
IfStmt,
JsonDict,
MemberExpr,
NameExpr,
PlaceholderNode,
RefExpr,
Statement,
TempNode,
TypeInfo,
Var,
)
from mypy.options import Options
from mypy.plugin import ClassDefContext, FunctionContext, Plugin, SemanticAnalyzerPluginInterface
from mypy.plugins.common import (
add_method_to_class,
deserialize_and_fixup_type,
)
from mypy.server.trigger import make_wildcard_trigger
from mypy.state import state
from mypy.typeops import map_type_from_supertype
from mypy.types import (
AnyType,
Instance,
NoneType,
Type,
TypeOfAny,
UnionType,
)
from mypy.typevars import fill_typevars
# Key under which collected TaskOnKart attribute metadata is stored on mypy TypeInfo objects.
METADATA_TAG: Final[str] = 'task_on_kart'

# Matches fully-qualified parameter class names, e.g. 'luigi.parameter.IntParameter' or 'gokart.parameter.TaskInstanceParameter'.
PARAMETER_FULLNAME_MATCHER: Final = re.compile(r'^(gokart|luigi)(\.parameter)?\.\w*Parameter$')

# Matches bare parameter class names, e.g. 'Parameter' or 'IntParameter'.
PARAMETER_TMP_MATCHER: Final = re.compile(r'^\w*Parameter$')
class PluginOptions(Enum):
    """Configuration keys recognized by the plugin (under `[tool.gokart-mypy]` in pyproject.toml)."""

    DISALLOW_MISSING_PARAMETERS = 'disallow_missing_parameters'
@dataclass
class TaskOnKartPluginOptions:
    """Options controlling the gokart mypy plugin, loaded from pyproject.toml."""

    # Whether to error on missing parameters in the constructor.
    # Some projects use luigi.Config to set parameters, which does not require parameters to be explicitly passed to the constructor.
    disallow_missing_parameters: bool = False

    @classmethod
    def _parse_toml(cls, config_file: str) -> dict[str, Any]:
        """Load ``config_file`` as TOML, returning {} when no TOML parser is available."""
        if sys.version_info >= (3, 11):
            import tomllib as toml_
        else:
            try:
                import tomli as toml_
            except ImportError:  # pragma: no cover
                warnings.warn('install tomli to parse pyproject.toml under Python 3.10', stacklevel=1)
                return {}
        with open(config_file, 'rb') as f:
            return toml_.load(f)

    @classmethod
    def parse_config_file(cls, config_file: str) -> TaskOnKartPluginOptions:
        """Build options from mypy's config file path, falling back to defaults."""
        # TODO: support other configuration file formats if necessary.
        if not config_file.endswith('.toml'):
            warnings.warn('gokart mypy plugin can be configured by pyproject.toml', stacklevel=1)
            return cls()

        parsed = cls._parse_toml(config_file)
        plugin_section = parsed.get('tool', {}).get('gokart-mypy', {})
        disallow_missing_parameters = plugin_section.get(PluginOptions.DISALLOW_MISSING_PARAMETERS.value, False)
        if not isinstance(disallow_missing_parameters, bool):
            raise ValueError(f'{PluginOptions.DISALLOW_MISSING_PARAMETERS.value} must be a boolean value')
        return cls(disallow_missing_parameters=disallow_missing_parameters)
class TaskOnKartPlugin(Plugin):
    """Mypy plugin that generates typed ``__init__`` signatures for TaskOnKart subclasses."""

    def __init__(self, options: Options) -> None:
        super().__init__(options)
        # Plugin options come from the same config file mypy itself is using.
        if options.config_file is not None:
            self._options = TaskOnKartPluginOptions.parse_config_file(options.config_file)
        else:
            self._options = TaskOnKartPluginOptions()

    def get_base_class_hook(self, fullname: str) -> Callable[[ClassDefContext], None] | None:
        # The following gathers attributes from gokart.TaskOnKart such as `workspace_directory`
        # the transformation does not affect because the class has `__init__` method of `gokart.TaskOnKart`.
        #
        # NOTE: `gokart.task.luigi.Task` condition is required for the release of luigi versions without py.typed
        if fullname in {'gokart.task.luigi.Task', 'luigi.task.Task'}:
            return self._task_on_kart_class_maker_callback

        sym = self.lookup_fully_qualified(fullname)
        if sym and isinstance(sym.node, TypeInfo):
            # Transform any class whose MRO includes gokart.task.TaskOnKart.
            if any(base.fullname == 'gokart.task.TaskOnKart' for base in sym.node.mro):
                return self._task_on_kart_class_maker_callback
        return None

    def get_function_hook(self, fullname: str) -> Callable[[FunctionContext], Type] | None:
        """Adjust the return type of the `Parameters` function."""
        if PARAMETER_FULLNAME_MATCHER.match(fullname):
            return self._task_on_kart_parameter_field_callback
        return None

    def _task_on_kart_class_maker_callback(self, ctx: ClassDefContext) -> None:
        # Delegate the actual class transformation (attribute collection and
        # __init__ generation) to TaskOnKartTransformer.
        transformer = TaskOnKartTransformer(ctx.cls, ctx.reason, ctx.api, self._options)
        transformer.transform()

    def _task_on_kart_parameter_field_callback(self, ctx: FunctionContext) -> Type:
        """Extract the type of the `default` argument from the Field function, and use it as the return type.

        In particular:
        * Retrieve the type of the argument which is specified, and use it as return type for the function.
        * If no default argument is specified, return AnyType with unannotated type instead of parameter types like `luigi.Parameter()`
          This makes mypy avoid conflict between the type annotation and the parameter type.
          e.g.
          ```python
          foo: int = luigi.IntParameter()
          ```
        """
        try:
            default_idx = ctx.callee_arg_names.index('default')
        # if no `default` argument is found, return AnyType with unannotated type.
        except ValueError:
            return AnyType(TypeOfAny.unannotated)

        default_args = ctx.args[default_idx]

        if default_args:
            default_type = ctx.arg_types[0][0]
            default_arg = default_args[0]

            # Fallback to default Any type if the field is required
            if not isinstance(default_arg, EllipsisExpr):
                return default_type

        # NOTE: This is a workaround to avoid the error between type annotation and parameter type.
        # As the following code snippet, the type of `foo` is `int` but the assigned value is `luigi.IntParameter()`.
        #   foo: int = luigi.IntParameter()
        # TODO: infer mypy type from the parameter type.
        return AnyType(TypeOfAny.unannotated)
class TaskOnKartAttribute:
    """A single parameter attribute collected from a TaskOnKart class body,
    carrying everything needed to render it as a ``__init__`` argument."""

    def __init__(
        self,
        name: str,
        has_default: bool,
        line: int,
        column: int,
        type: Type | None,
        info: TypeInfo,
        api: SemanticAnalyzerPluginInterface,
        options: TaskOnKartPluginOptions,
    ) -> None:
        self.name = name
        self.has_default = has_default
        self.line = line
        self.column = column
        self.type = type  # Type as __init__ argument
        self.info = info
        self._api = api
        self._options = options

    def to_argument(self, current_info: TypeInfo, *, of: Literal['__init__',]) -> Argument:
        """Render this attribute as a mypy ``Argument`` for the generated method ``of``."""
        if of == '__init__':
            arg_kind = self._get_arg_kind_by_options()
        return Argument(
            variable=self.to_var(current_info),
            type_annotation=self.expand_type(current_info),
            initializer=EllipsisExpr() if self.has_default else None,  # Only used by stubgen
            kind=arg_kind,
        )

    def expand_type(self, current_info: TypeInfo) -> Type | None:
        """Return this attribute's type with Self type variables expanded for ``current_info``."""
        if self.type is not None and self.info.self_type is not None:
            # In general, it is not safe to call `expand_type()` during semantic analysis,
            # however this plugin is called very late, so all types should be fully ready.
            # Also, it is tricky to avoid eager expansion of Self types here (e.g. because
            # we serialize attributes).
            with state.strict_optional_set(self._api.options.strict_optional):
                return expand_type(self.type, {self.info.self_type.id: fill_typevars(current_info)})
        return self.type

    def to_var(self, current_info: TypeInfo) -> Var:
        return Var(self.name, self.expand_type(current_info))

    def serialize(self) -> JsonDict:
        """Serialize this attribute for mypy's incremental metadata cache."""
        assert self.type
        return {
            'name': self.name,
            'has_default': self.has_default,
            'line': self.line,
            'column': self.column,
            'type': self.type.serialize(),
        }

    @classmethod
    def deserialize(cls, info: TypeInfo, data: JsonDict, api: SemanticAnalyzerPluginInterface, options: TaskOnKartPluginOptions) -> TaskOnKartAttribute:
        """Inverse of ``serialize``, re-attaching the live ``info``, ``api``, and ``options``."""
        data = data.copy()
        typ = deserialize_and_fixup_type(data.pop('type'), api)
        return cls(type=typ, info=info, **data, api=api, options=options)

    def expand_typevar_from_subtype(self, sub_type: TypeInfo) -> None:
        """Expands type vars in the context of a subtype when an attribute is inherited
        from a generic super type."""
        if self.type is not None:
            with state.strict_optional_set(self._api.options.strict_optional):
                self.type = map_type_from_supertype(self.type, sub_type, self.info)

    def _get_arg_kind_by_options(self) -> Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]:
        """Set the argument kind based on the options.

        if `disallow_missing_parameters` is True, the argument kind is `ARG_NAMED` when the attribute has no default value.
        This means the that all the parameters are passed to the constructor as keyword-only arguments.

        Returns:
            Literal[ArgKind.ARG_NAMED, ArgKind.ARG_NAMED_OPT]: The argument kind.
        """
        if not self._options.disallow_missing_parameters:
            return ARG_NAMED_OPT
        if self.has_default:
            return ARG_NAMED_OPT
        # required parameter
        return ARG_NAMED
class TaskOnKartTransformer:
    """Implement the behavior of gokart.TaskOnKart.

    Collects parameter attributes from a class body (and its ancestors'
    recorded metadata) and generates a keyword-only ``__init__`` signature.
    """

    def __init__(
        self,
        cls: ClassDef,
        reason: Expression | Statement,
        api: SemanticAnalyzerPluginInterface,
        options: TaskOnKartPluginOptions,
    ) -> None:
        self._cls = cls
        self._reason = reason
        self._api = api
        self._options = options

    def transform(self) -> bool:
        """Apply all the necessary transformations to the underlying gokart.TaskOnKart"""
        info = self._cls.info
        attributes = self.collect_attributes()

        if attributes is None:
            # Some definitions are not ready. We need another pass.
            return False
        for attr in attributes:
            if attr.type is None:
                # Attribute types are not resolved yet; ask for another pass.
                return False
        # If there are no attributes, it may be that the semantic analyzer has not
        # processed them yet. In order to work around this, we can simply skip generating
        # __init__ if there are no attributes, because if the user truly did not define any,
        # then the object default __init__ with an empty signature will be present anyway.
        if ('__init__' not in info.names or info.names['__init__'].plugin_generated) and attributes:
            args = [attr.to_argument(info, of='__init__') for attr in attributes]
            add_method_to_class(self._api, self._cls, '__init__', args=args, return_type=NoneType())

        # Record the collected attributes on the class so subclasses can inherit them.
        info.metadata[METADATA_TAG] = {
            'attributes': [attr.serialize() for attr in attributes],
        }

        return True

    def _get_assignment_statements_from_if_statement(self, stmt: IfStmt) -> Iterator[AssignmentStmt]:
        # Walk both branches of the `if`, skipping statically-unreachable blocks.
        for body in stmt.body:
            if not body.is_unreachable:
                yield from self._get_assignment_statements_from_block(body)
        if stmt.else_body is not None and not stmt.else_body.is_unreachable:
            yield from self._get_assignment_statements_from_block(stmt.else_body)

    def _get_assignment_statements_from_block(self, block: Block) -> Iterator[AssignmentStmt]:
        # Yield top-level assignments, recursing into `if` statements only.
        for stmt in block.body:
            if isinstance(stmt, AssignmentStmt):
                yield stmt
            elif isinstance(stmt, IfStmt):
                yield from self._get_assignment_statements_from_if_statement(stmt)

    def collect_attributes(self) -> list[TaskOnKartAttribute] | None:
        """Collect all attributes declared in the task and its parents.

        All assignments of the form

          a: SomeType
          b: SomeOtherType = ...

        are collected.

        Return None if some base class hasn't been processed
        yet and thus we'll need to ask for another pass.
        """
        cls = self._cls

        # First, collect attributes belonging to any class in the MRO, ignoring duplicates.
        #
        # We iterate through the MRO in reverse because attrs defined in the parent must appear
        # earlier in the attributes list than attrs defined in the child.
        #
        # However, we also want attributes defined in the subtype to override ones defined
        # in the parent. We can implement this via a dict without disrupting the attr order
        # because dicts preserve insertion order in Python 3.7+.
        found_attrs: dict[str, TaskOnKartAttribute] = {}
        for info in reversed(cls.info.mro[1:-1]):
            if METADATA_TAG not in info.metadata:
                continue

            # Each class depends on the set of attributes in its task_on_kart ancestors.
            self._api.add_plugin_dependency(make_wildcard_trigger(info.fullname))

            for data in info.metadata[METADATA_TAG]['attributes']:
                name: str = data['name']

                attr = TaskOnKartAttribute.deserialize(info, data, self._api, self._options)
                # TODO: We shouldn't be performing type operations during the main
                # semantic analysis pass, since some TypeInfo attributes might
                # still be in flux. This should be performed in a later phase.
                attr.expand_typevar_from_subtype(cls.info)
                found_attrs[name] = attr

                sym_node = cls.info.names.get(name)
                if sym_node and sym_node.node and not isinstance(sym_node.node, Var):
                    self._api.fail(
                        'TaskOnKart attribute may only be overridden by another attribute',
                        sym_node.node,
                    )

        # Second, collect attributes belonging to the current class.
        current_attr_names: set[str] = set()
        for stmt in self._get_assignment_statements_from_block(cls.defs):
            if not is_parameter_call(stmt.rvalue):
                # Only assignments whose right-hand side is a Parameter() call count.
                continue

            # a: int, b: str = 1, 'foo' is not supported syntax so we
            # don't have to worry about it.
            lhs = stmt.lvalues[0]
            if not isinstance(lhs, NameExpr):
                continue

            sym = cls.info.names.get(lhs.name)
            if sym is None:
                # There was probably a semantic analysis error.
                continue

            node = sym.node
            assert not isinstance(node, PlaceholderNode)
            assert isinstance(node, Var)

            has_parameter_call, parameter_args = self._collect_parameter_args(stmt.rvalue)
            has_default = False
            # Ensure that something like x: int = field() is rejected
            # after an attribute with a default.
            if has_parameter_call:
                has_default = 'default' in parameter_args

            # All other assignments are already type checked.
            elif not isinstance(stmt.rvalue, TempNode):
                has_default = True

            if not has_default:
                # Make all non-default task_on_kart attributes implicit because they are de-facto
                # set on self in the generated __init__(), not in the class body. On the other
                # hand, we don't know how custom task_on_kart transforms initialize attributes,
                # so we don't treat them as implicit. This is required to support descriptors
                # (https://github.com/python/mypy/issues/14868).
                sym.implicit = True

            current_attr_names.add(lhs.name)
            with state.strict_optional_set(self._api.options.strict_optional):
                init_type = sym.type

            # infer Parameter type
            if init_type is None:
                init_type = self._infer_type_from_parameters(stmt.rvalue)

            found_attrs[lhs.name] = TaskOnKartAttribute(
                name=lhs.name,
                has_default=has_default,
                line=stmt.line,
                column=stmt.column,
                type=init_type,
                info=cls.info,
                api=self._api,
                options=self._options,
            )

        return list(found_attrs.values())

    def _collect_parameter_args(self, expr: Expression) -> tuple[bool, dict[str, Expression]]:
        """Returns a tuple where the first value represents whether or not
        the expression is a call to luigi.Parameter() or gokart.TaskInstanceParameter()
        and the second value is a dictionary of the keyword arguments that luigi.Parameter() or gokart.TaskInstanceParameter() was called with.
        """
        if isinstance(expr, CallExpr) and isinstance(expr.callee, RefExpr):
            args = {}
            for name, arg in zip(expr.arg_names, expr.args, strict=False):
                if name is None:
                    # NOTE: this is a workaround to get default value from a parameter
                    self._api.fail(
                        'Positional arguments are not allowed for parameters when using the mypy plugin. '
                        "Update your code to use named arguments, like luigi.Parameter(default='foo') instead of luigi.Parameter('foo')",
                        expr,
                    )
                    continue
                args[name] = arg
            return True, args
        return False, {}

    def _infer_type_from_parameters(self, parameter: Expression) -> Type | None:
        """
        Generate default type from Parameter.
        For example, when parameter is `luigi.parameter.Parameter`, this method should return `str` type.
        """
        parameter_name = _extract_parameter_name(parameter)
        if parameter_name is None:
            return None

        # Map each known parameter class to the Python type its value carries.
        underlying_type: Type | None = None
        if parameter_name in ['luigi.parameter.Parameter', 'luigi.parameter.OptionalParameter']:
            underlying_type = self._api.named_type('builtins.str', [])
        elif parameter_name in ['luigi.parameter.IntParameter', 'luigi.parameter.OptionalIntParameter']:
            underlying_type = self._api.named_type('builtins.int', [])
        elif parameter_name in ['luigi.parameter.FloatParameter', 'luigi.parameter.OptionalFloatParameter']:
            underlying_type = self._api.named_type('builtins.float', [])
        elif parameter_name in ['luigi.parameter.BoolParameter', 'luigi.parameter.OptionalBoolParameter']:
            underlying_type = self._api.named_type('builtins.bool', [])
        elif parameter_name in ['luigi.parameter.DateParameter', 'luigi.parameter.MonthParameter', 'luigi.parameter.YearParameter']:
            underlying_type = self._api.named_type('datetime.date', [])
        elif parameter_name in ['luigi.parameter.DateHourParameter', 'luigi.parameter.DateMinuteParameter', 'luigi.parameter.DateSecondParameter']:
            underlying_type = self._api.named_type('datetime.datetime', [])
        elif parameter_name in ['luigi.parameter.TimeDeltaParameter']:
            underlying_type = self._api.named_type('datetime.timedelta', [])
        elif parameter_name in ['luigi.parameter.DictParameter', 'luigi.parameter.OptionalDictParameter']:
            underlying_type = self._api.named_type('builtins.dict', [AnyType(TypeOfAny.unannotated), AnyType(TypeOfAny.unannotated)])
        elif parameter_name in ['luigi.parameter.ListParameter', 'luigi.parameter.OptionalListParameter']:
            underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
        elif parameter_name in ['luigi.parameter.TupleParameter', 'luigi.parameter.OptionalTupleParameter']:
            underlying_type = self._api.named_type('builtins.tuple', [AnyType(TypeOfAny.unannotated)])
        elif parameter_name in ['luigi.parameter.PathParameter', 'luigi.parameter.OptionalPathParameter']:
            underlying_type = self._api.named_type('pathlib.Path', [])
        elif parameter_name in ['gokart.parameter.TaskInstanceParameter']:
            underlying_type = self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])
        elif parameter_name in ['gokart.parameter.ListTaskInstanceParameter']:
            underlying_type = self._api.named_type('builtins.list', [self._api.named_type('gokart.task.TaskOnKart', [AnyType(TypeOfAny.unannotated)])])
        elif parameter_name in ['gokart.parameter.ExplicitBoolParameter']:
            underlying_type = self._api.named_type('builtins.bool', [])
        elif parameter_name in ['luigi.parameter.NumericalParameter']:
            underlying_type = self._get_type_from_args(parameter, 'var_type')
        elif parameter_name in ['luigi.parameter.ChoiceParameter']:
            underlying_type = self._get_type_from_args(parameter, 'var_type')
        elif parameter_name in ['luigi.parameter.ChoiceListParameter']:
            base_type = self._get_type_from_args(parameter, 'var_type')
            if base_type is not None:
                underlying_type = self._api.named_type('builtins.tuple', [base_type])
        elif parameter_name in ['luigi.parameter.EnumParameter']:
            underlying_type = self._get_type_from_args(parameter, 'enum')
        elif parameter_name in ['luigi.parameter.EnumListParameter']:
            base_type = self._get_type_from_args(parameter, 'enum')
            if base_type is not None:
                underlying_type = self._api.named_type('builtins.tuple', [base_type])

        if underlying_type is None:
            return None

        # When parameter has Optional, it can be none value.
        if 'Optional' in parameter_name:
            return UnionType([underlying_type, NoneType()])
        return underlying_type

    def _get_type_from_args(self, parameter: Expression, arg_key: str) -> Type | None:
        """
        get type from parameter arguments.

        e.x)
        When parameter is `luigi.ChoiceParameter(var_type=int)`, this method should return `int` type.
        """
        ok, args = self._collect_parameter_args(parameter)
        if not ok:
            return None
        if arg_key not in args:
            return None
        arg = args[arg_key]
        if not isinstance(arg, NameExpr):
            # Only a plain name referring to a class can be turned into a type here.
            return None
        if not isinstance(arg.node, TypeInfo):
            return None
        return Instance(arg.node, [])
def is_parameter_call(expr: Expression) -> bool:
    """Checks if the expression is a call to luigi.Parameter()"""
    name = _extract_parameter_name(expr)
    return name is not None and PARAMETER_FULLNAME_MATCHER.match(name) is not None
def _extract_parameter_name(expr: Expression) -> str | None:
    """Extract name if the expression is a call to luigi.Parameter()"""
    if not isinstance(expr, CallExpr):
        return None

    callee = expr.callee
    if isinstance(callee, MemberExpr):
        type_info = callee.node
        if type_info is None and isinstance(callee.expr, NameExpr):
            # Unresolved attribute access (no symbol yet): fall back to the
            # textual dotted name, e.g. 'luigi.Parameter'.
            return f'{callee.expr.name}.{callee.name}'
    elif isinstance(callee, NameExpr):
        type_info = callee.node
    else:
        return None

    if isinstance(type_info, TypeInfo):
        return type_info.fullname

    # Currently, luigi doesn't provide py.typed. it will be released next to 3.5.1.
    # https://github.com/spotify/luigi/pull/3297
    # With the following code, we can't assume correctly.
    #
    #   from luigi import Parameter
    #   class MyTask(gokart.TaskOnKart):
    #       param = Parameter()
    # NOTE(review): this is a lexicographic string comparison, not a semantic
    # version comparison (e.g. '3.10.0' <= '3.5.1' is True) — confirm intended.
    if isinstance(type_info, Var) and luigi.__version__ <= '3.5.1':
        return type_info.name
    return None
def plugin(version: str) -> type[Plugin]:
    """Mypy plugin entry point.

    ``version`` is the mypy version string; it is not used to vary behavior here.
    """
    return TaskOnKartPlugin
================================================
FILE: gokart/object_storage.py
================================================
from __future__ import annotations
from datetime import datetime
from typing import cast
import luigi
import luigi.contrib.gcs
import luigi.contrib.s3
from luigi.format import Format
from gokart.gcs_config import GCSConfig
from gokart.gcs_zip_client import GCSZipClient
from gokart.s3_config import S3Config
from gokart.s3_zip_client import S3ZipClient
from gokart.zip_client import ZipClient
# URL scheme prefixes recognized as object-storage paths.
object_storage_path_prefix = ['s3://', 'gs://']


class ObjectStorage:
    """Static dispatch helpers for S3 (`s3://`) and GCS (`gs://`) paths.

    Fixed: the unsupported-scheme branches previously used a bare ``raise``
    outside any ``except``, which fails with an unhelpful
    ``RuntimeError('No active exception to re-raise')``; they now raise an
    informative ``ValueError``.
    """

    @staticmethod
    def if_object_storage_path(path: str) -> bool:
        """Return True when ``path`` starts with a supported object-storage scheme."""
        return any(path.startswith(prefix) for prefix in object_storage_path_prefix)

    @staticmethod
    def get_object_storage_target(path: str, format: Format) -> luigi.target.FileSystemTarget:
        """Build a luigi FileSystemTarget for ``path``.

        Raises:
            ValueError: if the path scheme is not supported.
        """
        if path.startswith('s3://'):
            return luigi.contrib.s3.S3Target(path, client=S3Config().get_s3_client(), format=format)
        elif path.startswith('gs://'):
            return luigi.contrib.gcs.GCSTarget(path, client=GCSConfig().get_gcs_client(), format=format)
        else:
            raise ValueError(f'Unsupported object storage path: {path}')

    @staticmethod
    def exists(path: str) -> bool:
        """Return whether the object at ``path`` exists.

        Raises:
            ValueError: if the path scheme is not supported.
        """
        if path.startswith('s3://'):
            return cast(bool, S3Config().get_s3_client().exists(path))
        elif path.startswith('gs://'):
            return cast(bool, GCSConfig().get_gcs_client().exists(path))
        else:
            raise ValueError(f'Unsupported object storage path: {path}')

    @staticmethod
    def get_timestamp(path: str) -> datetime:
        """Return the last-modified timestamp of the object at ``path``.

        Raises:
            ValueError: if the path scheme is not supported.
        """
        if path.startswith('s3://'):
            return cast(datetime, S3Config().get_s3_client().get_key(path).last_modified)
        elif path.startswith('gs://'):
            # for gcs object
            # should PR to luigi
            bucket, obj = GCSConfig().get_gcs_client()._path_to_bucket_and_key(path)
            result = GCSConfig().get_gcs_client().client.objects().get(bucket=bucket, object=obj).execute()
            return cast(datetime, result['updated'])
        else:
            raise ValueError(f'Unsupported object storage path: {path}')

    @staticmethod
    def get_zip_client(file_path: str, temporary_directory: str) -> ZipClient:
        """Return the ZipClient implementation matching ``file_path``'s scheme.

        Raises:
            ValueError: if the path scheme is not supported.
        """
        if file_path.startswith('s3://'):
            return S3ZipClient(file_path=file_path, temporary_directory=temporary_directory)
        elif file_path.startswith('gs://'):
            return GCSZipClient(file_path=file_path, temporary_directory=temporary_directory)
        else:
            raise ValueError(f'Unsupported object storage path: {file_path}')

    @staticmethod
    def is_buffered_reader(file: object) -> bool:
        # S3's readable file object is the one known non-buffered case here.
        return not isinstance(file, luigi.contrib.s3.ReadableS3File)
================================================
FILE: gokart/pandas_type_config.py
================================================
from __future__ import annotations
from abc import abstractmethod
from logging import getLogger
from typing import Any
import luigi
import numpy as np
import pandas as pd
from luigi.task_register import Register
logger = getLogger(__name__)
class PandasTypeError(Exception):
    """Raised when the type of the pandas DataFrame column is not as expected.

    Raised by `PandasTypeConfig._check_column` for the first non-matching
    value found in a configured column.
    """
class PandasTypeConfig(luigi.Config):
    """Declares expected Python element types for DataFrame columns.

    Subclasses implement `type_dict` mapping column name -> expected type;
    `check` validates a DataFrame against that mapping.
    """

    @classmethod
    @abstractmethod
    def type_dict(cls) -> dict[str, Any]:
        """Return a mapping of column name to the expected element type."""
        pass

    @classmethod
    def check(cls, df: pd.DataFrame) -> None:
        """Validate ``df`` against ``type_dict``.

        Raises:
            PandasTypeError: when a configured column contains a value of the wrong type.
        """
        for column_name, column_type in cls.type_dict().items():
            cls._check_column(df, column_name, column_type)

    @classmethod
    def _check_column(cls, df: pd.DataFrame, column_name: str, column_type: type) -> None:
        # Columns absent from the frame are simply skipped.
        if column_name not in df.columns:
            return
        # Short-circuit on the first offending value instead of materializing a
        # full boolean list (the previous np.all(list(map(...)))) and then
        # scanning the column a second time with filter().
        # A sentinel is required because the offending value itself may be None.
        _missing = object()
        not_match = next((x for x in df[column_name] if not isinstance(x, column_type)), _missing)
        if not_match is not _missing:
            raise PandasTypeError(f'expected type is "{column_type}", but "{type(not_match)}" is passed in column "{column_name}".')
class PandasTypeConfigMap(luigi.Config):
    """To initialize this class only once, this inherits luigi.Config."""

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # Build a namespace -> config-class lookup over every registered task,
        # keeping only concrete PandasTypeConfig subclasses.
        registered_classes = [Register.get_task_cls(task_name) for task_name in Register.task_names()]
        self._map = {
            candidate.task_namespace: candidate
            for candidate in registered_classes
            if issubclass(candidate, PandasTypeConfig) and candidate != PandasTypeConfig
        }

    def check(self, obj: Any, task_namespace: str) -> None:
        """Run the namespace's type check when ``obj`` is a DataFrame and a config is registered."""
        if isinstance(obj, pd.DataFrame) and task_namespace in self._map:
            self._map[task_namespace].check(obj)
================================================
FILE: gokart/parameter.py
================================================
from __future__ import annotations
import bz2
import datetime
import json
import sys
from logging import getLogger
from typing import Any, Generic, Protocol, TypeVar
if sys.version_info >= (3, 11):
from typing import Unpack
else:
from typing_extensions import Unpack
from warnings import warn
import luigi
from luigi import task_register
try:
from luigi.parameter import _no_value, _NoValueType, _ParameterKwargs
except ImportError:
_no_value = None # type: ignore[assignment]
_NoValueType = type(None) # type: ignore[assignment,misc]
_ParameterKwargs = dict # type: ignore[assignment,misc]
import gokart
logger = getLogger(__name__)
TASK_ON_KART_TYPE = TypeVar('TASK_ON_KART_TYPE', bound='gokart.TaskOnKart')  # type: ignore


class TaskInstanceParameter(luigi.Parameter[TASK_ON_KART_TYPE], Generic[TASK_ON_KART_TYPE]):
    """Parameter whose value is a task instance.

    Serialized form is a JSON dict ``{'type': <task family>, 'params': <hex of bz2-compressed JSON str-params>}``
    (see `serialize`); `parse` reverses that encoding.
    """

    def __init__(
        self,
        expected_type: type[TASK_ON_KART_TYPE] | None = None,
        default: TASK_ON_KART_TYPE | _NoValueType = _no_value,
        **kwargs: Unpack[_ParameterKwargs],
    ):
        # `expected_type` restricts which task classes are accepted; defaults to any TaskOnKart.
        if expected_type is None:
            self.expected_type: type = gokart.TaskOnKart
        elif isinstance(expected_type, type):
            self.expected_type = expected_type
        else:
            raise TypeError(f'expected_type must be a type, not {type(expected_type)}')
        super().__init__(default=default, **kwargs)

    @staticmethod
    def _recursive(param_dict):
        # Re-instantiate the task: each serialized value is parsed with the
        # parameter object declared on the task class, then the task is built.
        params = param_dict['params']
        task_cls = task_register.Register.get_task_cls(param_dict['type'])
        for key, value in task_cls.get_params():
            if key in params:
                params[key] = value.parse(params[key])
        return task_cls(**params)

    @staticmethod
    def _recursive_decompress(s):
        # 'params' carries the bz2+hex encoding produced by `serialize`;
        # recursively decode any nested 'params' entry with the same encoding.
        s = dict(luigi.DictParameter().parse(s))
        if 'params' in s:
            s['params'] = TaskInstanceParameter._recursive_decompress(bz2.decompress(bytes.fromhex(s['params'])).decode())
        return s

    def parse(self, s):
        """Deserialize a task instance from its string form (inverse of `serialize`)."""
        if isinstance(s, str):
            s = self._recursive_decompress(s)
        return self._recursive(s)

    def serialize(self, x):
        """Serialize task `x`; significant str-params are bz2-compressed and hex-encoded to keep the value compact."""
        params = bz2.compress(json.dumps(x.to_str_params(only_significant=True)).encode()).hex()
        values = dict(type=x.get_task_family(), params=params)
        return luigi.DictParameter().serialize(values)

    def _warn_on_wrong_param_type(self, param_name, param_value):
        # NOTE: despite the luigi-compatible name, this raises instead of warning.
        if not isinstance(param_value, self.expected_type):
            raise TypeError(f'{param_value} is not an instance of {self.expected_type}')
class _TaskInstanceEncoder(json.JSONEncoder):
    """JSON encoder that serializes luigi task instances via TaskInstanceParameter."""

    def default(self, obj):
        if not isinstance(obj, luigi.Task):
            # Let the base class default method raise the TypeError
            return json.JSONEncoder.default(self, obj)
        return TaskInstanceParameter().serialize(obj)
class ListTaskInstanceParameter(luigi.Parameter[list[TASK_ON_KART_TYPE]], Generic[TASK_ON_KART_TYPE]):
    """Parameter holding a list of task instances, serialized as a JSON array."""

    def __init__(
        self,
        expected_elements_type: type[TASK_ON_KART_TYPE] | None = None,
        default: list[TASK_ON_KART_TYPE] | _NoValueType = _no_value,
        **kwargs: Unpack[_ParameterKwargs],
    ):
        # `expected_elements_type` restricts the element classes; defaults to any TaskOnKart.
        if expected_elements_type is None:
            self.expected_elements_type: type = gokart.TaskOnKart
        elif not isinstance(expected_elements_type, type):
            raise TypeError(f'expected_elements_type must be a type, not {type(expected_elements_type)}')
        else:
            self.expected_elements_type = expected_elements_type
        super().__init__(default=default, **kwargs)

    def parse(self, s):
        """Deserialize each JSON array element as a task instance."""
        return [TaskInstanceParameter().parse(element) for element in json.loads(s)]

    def serialize(self, x):
        """Serialize the list; task elements are encoded by _TaskInstanceEncoder."""
        return json.dumps(x, cls=_TaskInstanceEncoder)

    def _warn_on_wrong_param_type(self, param_name, param_value):
        # NOTE: despite the luigi-compatible name, this raises instead of warning.
        for element in param_value:
            if not isinstance(element, self.expected_elements_type):
                raise TypeError(f'{element} is not an instance of {self.expected_elements_type}')
class ExplicitBoolParameter(luigi.BoolParameter):
    """BoolParameter that requires an explicit value instead of luigi's implicit flag parsing."""

    def __init__(self, *args, **kwargs):
        # Deliberately call luigi.Parameter.__init__ (skipping BoolParameter.__init__)
        # so BoolParameter's implicit parsing setup is not applied.
        luigi.Parameter.__init__(self, *args, **kwargs)

    def _parser_kwargs(self, *args, **kwargs):  # type: ignore
        # BUG FIX: the previous `*kwargs` unpacked only the *keys* of kwargs as
        # positional arguments; `**kwargs` forwards them as keyword arguments.
        return luigi.Parameter._parser_kwargs(*args, **kwargs)
T = TypeVar('T')


class Serializable(Protocol):
    """Structural interface for objects usable with `SerializableParameter`."""

    def gokart_serialize(self) -> str:
        """Implement this method to serialize the object as a parameter.

        You can omit fields from the serialized form if you want to ignore changes to them.
        """
        ...

    @classmethod
    def gokart_deserialize(cls: type[T], s: str) -> T:
        """Implement this method to deserialize the object from a string."""
        ...
S = TypeVar('S', bound=Serializable)


class SerializableParameter(luigi.Parameter[S], Generic[S]):
    """Parameter for any object implementing the `Serializable` protocol."""

    def __init__(self, object_type: type[S], *args: Any, **kwargs: Any) -> None:
        # `object_type` supplies the concrete class whose `gokart_deserialize` is used in `parse`.
        self._object_type = object_type
        super().__init__(*args, **kwargs)

    def parse(self, s: str) -> S:
        return self._object_type.gokart_deserialize(s)

    def serialize(self, x: S) -> str:
        return x.gokart_serialize()
class ZonedDateSecondParameter(luigi.Parameter[datetime.datetime]):
    """
    Parameter for a timezone-aware ``datetime.datetime``.

    Values are ISO 8601 formatted dates with time specified to the second and a
    timezone offset, e.g. ``2013-07-10T19:07:38+09:00``. The ``:`` separator in
    the offset can be omitted for Python 3.11 and later.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def parse(self, s):
        # Python 3.11+ accepts a trailing 'Z' in fromisoformat; normalize it to
        # '+00:00' so earlier versions parse it as well.
        normalized = s[:-1] + '+00:00' if s.endswith('Z') else s
        parsed = datetime.datetime.fromisoformat(normalized)
        if parsed.tzinfo is None:
            warn('The input does not have timezone information. Please consider using luigi.DateSecondParameter instead.', stacklevel=1)
        return parsed

    def serialize(self, dt):
        return dt.isoformat()

    def normalize(self, dt):
        # Only drop microseconds (their digit count is not fixed); otherwise keep the
        # value untouched, unlike luigi's _DatetimeParameterBase.normalize.
        # See https://github.com/spotify/luigi/blob/v3.6.0/luigi/parameter.py#L612
        return dt.replace(microsecond=0)
================================================
FILE: gokart/py.typed
================================================
================================================
FILE: gokart/required_task_output.py
================================================
from dataclasses import dataclass
@dataclass
class RequiredTaskOutput:
    """Name and output path of an upstream task, recorded alongside a dumped output."""

    task_name: str
    output_path: str

    def serialize(self) -> dict[str, str]:
        """Return the gokart-namespaced dict representation of this record."""
        return {
            '__gokart_task_name': self.task_name,
            '__gokart_output_path': self.output_path,
        }
================================================
FILE: gokart/run.py
================================================
from __future__ import annotations
import logging
import os
import sys
from logging import getLogger
from typing import Any
import luigi
import luigi.cmdline
import luigi.cmdline_parser
import luigi.execution_summary
import luigi.interface
import luigi.retcodes
import luigi.setup_logging
from luigi.cmdline_parser import CmdlineParser
import gokart
import gokart.slack
from gokart.build import WorkerSchedulerFactory
from gokart.object_storage import ObjectStorage
logger = getLogger(__name__)
def _run_tree_info(cmdline_args, details):
    """Render the root task's dependency tree and dump it to the tree_info output target."""
    with CmdlineParser.global_instance(cmdline_args) as cp:
        target = gokart.tree_info().output()
        target.dump(gokart.make_tree_info(cp.get_task_obj(), details=details))
def _try_tree_info(cmdline_args):
    """Output the dependency tree and exit the process, when --tree-info-mode is set."""
    with CmdlineParser.global_instance(cmdline_args):
        mode = gokart.tree_info().mode
        output_path = gokart.tree_info().output().path()

    # do nothing if `mode` is empty.
    if mode == '':
        return

    details_of_mode = {'simple': False, 'all': True}
    if mode not in details_of_mode:
        raise ValueError(f'--tree-info-mode must be "simple" or "all", but "{mode}" is passed.')
    _run_tree_info(cmdline_args, details=details_of_mode[mode])
    logger.info(f'output tree info: {output_path}')
    sys.exit()
def _try_to_delete_unnecessary_output_file(cmdline_args: list[str]) -> None:
    """Delete stale local outputs and exit, when --delete-unnecessary-output-files is set."""
    with CmdlineParser.global_instance(cmdline_args) as cp:
        task: gokart.TaskOnKart[Any] = cp.get_task_obj()
        if not task.delete_unnecessary_output_files:
            return
        if ObjectStorage.if_object_storage_path(task.workspace_directory):
            # Remote workspaces are left untouched.
            logger.info('delete-unnecessary-output-files is not support s3/gcs.')
        else:
            gokart.delete_local_unnecessary_outputs(task)
        sys.exit()
def _try_get_slack_api(cmdline_args: list[str]) -> gokart.slack.SlackAPI | None:
    """Return a SlackAPI client when both a token and a channel are configured, else None."""
    with CmdlineParser.global_instance(cmdline_args):
        config = gokart.slack.SlackConfig()
        token = os.getenv(config.token_name, '')
        if not (token and config.channel):
            logger.info('Slack notification is not activated.')
            return None
        logger.info('Slack notification is activated.')
        return gokart.slack.SlackAPI(token=token, channel=config.channel, to_user=config.to_user)
def _try_to_send_event_summary_to_slack(
    slack_api: gokart.slack.SlackAPI | None, event_aggregator: gokart.slack.EventAggregator, cmdline_args: list[str]
) -> None:
    """Send the aggregated run events (and optionally the task tree) to slack as a snippet."""
    if slack_api is None:
        # Slack notification is not configured; nothing to send.
        return

    options = gokart.slack.SlackConfig()
    with CmdlineParser.global_instance(cmdline_args) as cp:
        task = cp.get_task_obj()
        if options.send_tree_info:
            tree_info = gokart.make_tree_info(task, details=True)
        else:
            tree_info = 'Please add SlackConfig.send_tree_info to include tree-info'
        comment = f'Report of {type(task).__name__}' + os.linesep + event_aggregator.get_summary()
        content = os.linesep.join(['===== Event List ====', event_aggregator.get_event_list(), os.linesep, '==== Tree Info ====', tree_info])
        slack_api.send_snippet(comment=comment, title='event.txt', content=content)
def _run_with_retcodes(argv):
    """run_with_retcodes equivalent that uses gokart's WorkerSchedulerFactory."""
    retcode_logger = logging.getLogger('luigi-interface')
    with luigi.cmdline_parser.CmdlineParser.global_instance(argv):
        # Read the configured exit codes while the command-line config is in scope.
        retcodes = luigi.retcodes.retcode()
    worker = None
    try:
        worker = luigi.interface._run(argv, worker_scheduler_factory=WorkerSchedulerFactory()).worker
    except luigi.interface.PidLockAlreadyTakenExit:
        sys.exit(retcodes.already_running)
    except Exception:
        # Mirrors luigi.retcodes.run_with_retcodes: set up logging before recording the traceback.
        env_params = luigi.interface.core()
        luigi.setup_logging.InterfaceLogging.setup(env_params)
        retcode_logger.exception('Uncaught exception in luigi')
        sys.exit(retcodes.unhandled_exception)
    with luigi.cmdline_parser.CmdlineParser.global_instance(argv):
        task_sets = luigi.execution_summary._summary_dict(worker)
        root_task = luigi.execution_summary._root_task(worker)
        # Statuses that have at least one task in them.
        non_empty_categories = {k: v for k, v in task_sets.items() if v}.keys()

        def has(status):
            # Guard against typos: the status must be one luigi actually reports.
            assert status in luigi.execution_summary._ORDERED_STATUSES
            return status in non_empty_categories

        codes_and_conds = (
            (retcodes.missing_data, has('still_pending_ext')),
            (retcodes.task_failed, has('failed')),
            (retcodes.already_running, has('run_by_other_worker')),
            (retcodes.scheduling_error, has('scheduling_error')),
            (retcodes.not_run, has('not_run')),
        )
        # The highest applicable code wins; 0 means every condition was clear.
        expected_ret_code = max(code * (1 if cond else 0) for code, cond in codes_and_conds)
        if expected_ret_code == 0 and root_task not in task_sets['completed'] and root_task not in task_sets['already_done']:
            # The root task neither completed nor was already done: report "not run".
            sys.exit(retcodes.not_run)
        else:
            sys.exit(expected_ret_code)
def run(cmdline_args=None, set_retcode=True):
    """gokart's CLI entry point: configure exit codes, handle info/cleanup sub-modes, then run the task.

    :param cmdline_args: command-line arguments; defaults to ``sys.argv[1:]``.
    :param set_retcode: when True, override luigi's exit codes with gokart's non-zero defaults.
    """
    cmdline_args = cmdline_args or sys.argv[1:]
    if set_retcode:
        # Make each failure category exit with a distinct, non-zero code.
        luigi.retcodes.retcode.already_running = 10  # type: ignore
        luigi.retcodes.retcode.missing_data = 20  # type: ignore
        luigi.retcodes.retcode.not_run = 30  # type: ignore
        luigi.retcodes.retcode.task_failed = 40  # type: ignore
        luigi.retcodes.retcode.scheduling_error = 50  # type: ignore
    # Each `_try_*` helper exits the process itself when its mode is requested.
    _try_tree_info(cmdline_args)
    _try_to_delete_unnecessary_output_file(cmdline_args)
    gokart.testing.try_to_run_test_for_empty_data_frame(cmdline_args)
    slack_api = _try_get_slack_api(cmdline_args)
    event_aggregator = gokart.slack.EventAggregator()
    try:
        event_aggregator.set_handlers()
        _run_with_retcodes(cmdline_args)
    except SystemExit as e:
        # _run_with_retcodes always leaves via SystemExit; notify slack, then propagate the code.
        _try_to_send_event_summary_to_slack(slack_api, event_aggregator, cmdline_args)
        sys.exit(e.code)
================================================
FILE: gokart/s3_config.py
================================================
from __future__ import annotations
import os
import luigi
import luigi.contrib.s3
class S3Config(luigi.Config):
    """Holds AWS credential environment-variable names and caches a shared S3 client."""

    aws_access_key_id_name = luigi.Parameter(default='AWS_ACCESS_KEY_ID', description='AWS access key id environment variable.')
    aws_secret_access_key_name = luigi.Parameter(default='AWS_SECRET_ACCESS_KEY', description='AWS secret access key environment variable.')
    # Lazily-created client cache, shared singleton-style via luigi.Config.
    _client = None

    def get_s3_client(self) -> luigi.contrib.s3.S3Client:
        """Return the cached S3 client, creating it on first use."""
        if self._client is None:  # use cache as like singleton object
            self._client = self._get_s3_client()
        return self._client

    def _get_s3_client(self) -> luigi.contrib.s3.S3Client:
        credentials = {
            'aws_access_key_id': os.environ.get(self.aws_access_key_id_name),
            'aws_secret_access_key': os.environ.get(self.aws_secret_access_key_name),
        }
        return luigi.contrib.s3.S3Client(**credentials)
================================================
FILE: gokart/s3_zip_client.py
================================================
from __future__ import annotations
import os
import shutil
from typing import cast
from gokart.s3_config import S3Config
from gokart.zip_client import ZipClient, _unzip_file
class S3ZipClient(ZipClient):
    """ZipClient that stores a zipped local directory as a single archive on S3.

    The archive's format/extension is taken from `file_path`; the archive is built
    from (or extracted into) `temporary_directory` on the local filesystem.
    """

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory
        self._client = S3Config().get_s3_client()

    def exists(self) -> bool:
        return cast(bool, self._client.exists(self._file_path))

    def make_archive(self) -> None:
        """Zip the temporary directory and upload the archive to `file_path`."""
        if not os.path.exists(self._temporary_directory):
            # Check path existence since shutil.make_archive() of python 3.10+ does not check it.
            raise FileNotFoundError(f'Temporary directory {self._temporary_directory} is not found.')
        extension = os.path.splitext(self._file_path)[1]
        # BUG FIX: derive the archive base name from _temporary_file_path() so the
        # file shutil.make_archive creates is exactly the file put() uploads.
        # Previously base_name was the raw directory path, which disagrees with
        # _temporary_file_path() when `temporary_directory` ends with '/'
        # (producing e.g. 'tmp/.zip' while put() looked for 'tmp.zip').
        base_name = os.path.splitext(self._temporary_file_path())[0]
        shutil.make_archive(base_name=base_name, format=extension[1:], root_dir=self._temporary_directory)
        self._client.put(self._temporary_file_path(), self._file_path)

    def unpack_archive(self) -> None:
        """Download the archive from `file_path` and extract it into the temporary directory."""
        os.makedirs(self._temporary_directory, exist_ok=True)
        self._client.get(self._file_path, self._temporary_file_path())
        _unzip_file(fp=self._temporary_file_path(), extract_dir=self._temporary_directory)

    def remove(self) -> None:
        self._client.remove(self._file_path)

    @property
    def path(self) -> str:
        return self._file_path

    def _temporary_file_path(self):
        # Local archive path: temporary directory (sans one trailing '/') + archive extension.
        extension = os.path.splitext(self._file_path)[1]
        base_name = self._temporary_directory
        if base_name.endswith('/'):
            base_name = base_name[:-1]
        return base_name + extension
================================================
FILE: gokart/slack/__init__.py
================================================
from gokart.slack.event_aggregator import EventAggregator
from gokart.slack.slack_api import SlackAPI
from gokart.slack.slack_config import SlackConfig
from .slack_api import ChannelListNotLoadedError, ChannelNotFoundError, FileNotUploadedError
# Public interface of the gokart.slack package.
__all__ = [
    'ChannelListNotLoadedError',
    'ChannelNotFoundError',
    'FileNotUploadedError',
    'EventAggregator',
    'SlackAPI',
    'SlackConfig',
]
================================================
FILE: gokart/slack/event_aggregator.py
================================================
from __future__ import annotations
import os
from logging import getLogger
from typing import Any, TypedDict
import luigi
logger = getLogger(__name__)
class FailureEvent(TypedDict):
    """Record of a single task failure: the task's display string and the exception text."""

    task: str
    exception: str
class EventAggregator:
    """Collects luigi SUCCESS/FAILURE events and renders them as report text."""

    def __init__(self) -> None:
        self._success_events: 'list[str]' = []
        self._failure_events: 'list[FailureEvent]' = []

    def set_handlers(self):
        """Register this aggregator's callbacks for luigi's success/failure events."""
        luigi.Task.event_handler(luigi.Event.SUCCESS)(self._success)
        luigi.Task.event_handler(luigi.Event.FAILURE)(self._failure)

    def get_summary(self) -> str:
        """One-line count of collected successes and failures."""
        return f'Success: {len(self._success_events)}; Failure: {len(self._failure_events)}'

    def get_event_list(self) -> str:
        """Multi-line report listing failures first, then successes."""
        sections = []
        if self._failure_events:
            failure_lines = [f'Task: {failure["task"]}; Exception: {failure["exception"]}' for failure in self._failure_events]
            sections.append('---- Failure Tasks ----' + os.linesep + os.linesep.join(failure_lines))
        if self._success_events:
            sections.append('---- Success Tasks ----' + os.linesep + os.linesep.join(self._success_events))
        if not sections:
            return 'Tasks were not executed.'
        return ''.join(sections)

    def _success(self, task):
        self._success_events.append(self._task_to_str(task))

    def _failure(self, task, exception):
        self._failure_events.append({'task': self._task_to_str(task), 'exception': str(exception)})

    @staticmethod
    def _task_to_str(task: Any) -> str:
        return f'{type(task).__name__}:[{task.make_unique_id()}]'
================================================
FILE: gokart/slack/slack_api.py
================================================
from __future__ import annotations
from logging import getLogger
import slack_sdk
logger = getLogger(__name__)
class ChannelListNotLoadedError(RuntimeError):
    """Raised when the slack channel list could not be retrieved (empty response page)."""

    pass


class ChannelNotFoundError(RuntimeError):
    """Raised when the configured channel is not found among the listed public channels."""

    pass


class FileNotUploadedError(RuntimeError):
    """Raised when the snippet upload to slack reports a non-ok response."""

    pass
class SlackAPI:
    """Thin wrapper around `slack_sdk.WebClient` used to post run reports as snippets.

    All failures are logged as warnings and swallowed, so a broken slack setup
    never breaks the gokart run itself.
    """

    def __init__(self, token: str, channel: str, to_user: str) -> None:
        self._client = slack_sdk.WebClient(token=token)
        self._channel_id = self._get_channel_id(channel)
        # Normalize the mention target to '@name' form; empty string means "no mention".
        self._to_user = to_user if to_user == '' or to_user.startswith('@') else '@' + to_user

    def _get_channel_id(self, channel_name):
        # Resolve a channel name to its id; returns None (implicitly, after logging) on any failure.
        # NOTE(review): `params=params` passes the options as a single kwarg to slack_sdk's
        # conversations_list, and the for-loop relies on the SlackResponse being iterable
        # (cursor pagination) — confirm both against the slack_sdk version in use.
        params = {'exclude_archived': True, 'limit': 100}
        try:
            for channels in self._client.conversations_list(params=params):
                if not channels:
                    raise ChannelListNotLoadedError('Channel list is empty.')
                for channel in channels.get('channels', []):
                    if channel['name'] == channel_name:
                        return channel['id']
            raise ChannelNotFoundError(f'Channel {channel_name} is not found in public channels.')
        except Exception as e:
            # Deliberate best-effort: notification problems must not stop the job.
            logger.warning(f'The job will start without slack notification: {e}')

    def send_snippet(self, comment, title, content):
        """Upload `content` as a text snippet with an initial comment (mentioning `to_user` when set)."""
        try:
            request_body = dict(
                channels=self._channel_id, initial_comment=f'<{self._to_user}> {comment}' if self._to_user else comment, content=content, title=title
            )
            # NOTE(review): 'files.upload' is a legacy slack API method — verify it is still
            # enabled for the workspace/token before relying on this path.
            response = self._client.api_call('files.upload', data=request_body)
            if not response['ok']:
                raise FileNotUploadedError(f'Error while uploading file. The error reason is "{response["error"]}".')
        except Exception as e:
            # Best-effort, mirroring _get_channel_id: log and continue.
            logger.warning(f'Failed to send slack notification: {e}')
================================================
FILE: gokart/slack/slack_config.py
================================================
from __future__ import annotations
import luigi
class SlackConfig(luigi.Config):
    """luigi config section holding slack notification settings."""

    # Name of the environment variable that holds the token — not the token itself.
    token_name = luigi.Parameter(default='SLACK_TOKEN', description='slack token environment variable.')
    channel = luigi.Parameter(default='', significant=False, description='channel name for notification.')
    to_user = luigi.Parameter(default='', significant=False, description='Optional; user name who is supposed to be mentioned.')
    send_tree_info = luigi.BoolParameter(
        default=False,
        significant=False,
        description='When this option is true, the dependency tree of tasks is included in send message.'
        'It is recommended to set false to this option when notification takes long time.',
    )
================================================
FILE: gokart/target.py
================================================
from __future__ import annotations
import hashlib
import os
import shutil
from abc import abstractmethod
from datetime import datetime
from glob import glob
from logging import getLogger
from typing import Any, cast
import luigi
import numpy as np
import pandas as pd
from gokart.conflict_prevention_lock.task_lock import TaskLockParams, make_task_lock_params
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
from gokart.object_storage import ObjectStorage
from gokart.required_task_output import RequiredTaskOutput
from gokart.utils import FlattenableItems
from gokart.zip_client_util import make_zip_client
logger = getLogger(__name__)
class TargetOnKart(luigi.Target):
    """Abstract base for gokart targets: existence checks, lock-guarded load/dump/remove, and mtime.

    Concrete subclasses implement the underscore-prefixed hooks; the public
    methods wrap them with the task-lock machinery.
    """

    def exists(self) -> bool:
        return self._exists()

    def load(self) -> Any:
        # Loading is wrapped with the task lock so a concurrently-written file is not read half-done.
        return wrap_load_with_lock(func=self._load, task_lock_params=self._get_task_lock_params())()

    def dump(
        self,
        obj: Any,
        lock_at_dump: bool = True,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        """Persist `obj`; when `lock_at_dump` is True the write is guarded by the task lock."""
        if lock_at_dump:
            wrap_dump_with_lock(func=self._dump, task_lock_params=self._get_task_lock_params(), exist_check=self.exists)(
                obj=obj,
                task_params=task_params,
                custom_labels=custom_labels,
                required_task_outputs=required_task_outputs,
            )
        else:
            self._dump(obj=obj, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs)

    def remove(self) -> None:
        # Removing a non-existent target is a no-op.
        if self.exists():
            wrap_remove_with_lock(self._remove, task_lock_params=self._get_task_lock_params())()

    def last_modification_time(self) -> datetime:
        return self._last_modification_time()

    def path(self) -> str:
        return self._path()

    @abstractmethod
    def _exists(self) -> bool:
        pass

    @abstractmethod
    def _get_task_lock_params(self) -> TaskLockParams:
        pass

    @abstractmethod
    def _load(self) -> Any:
        pass

    @abstractmethod
    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        pass

    @abstractmethod
    def _remove(self) -> None:
        pass

    @abstractmethod
    def _last_modification_time(self) -> datetime:
        pass

    @abstractmethod
    def _path(self) -> str:
        pass
class SingleFileTarget(TargetOnKart):
    """TargetOnKart backed by a single file; (de)serialization is delegated to a FileProcessor."""

    def __init__(
        self,
        target: luigi.target.FileSystemTarget,
        processor: FileProcessor,
        task_lock_params: TaskLockParams,
    ) -> None:
        self._target = target
        self._processor = processor
        self._task_lock_params = task_lock_params

    def _exists(self) -> bool:
        return cast(bool, self._target.exists())

    def _get_task_lock_params(self) -> TaskLockParams:
        return self._task_lock_params

    def _load(self) -> Any:
        with self._target.open('r') as stream:
            return self._processor.load(stream)

    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        with self._target.open('w') as stream:
            self._processor.dump(obj, stream)
        # GCS objects additionally get task state recorded as object metadata labels.
        if self.path().startswith('gs://'):
            GCSObjectMetadataClient.add_task_state_labels(
                path=self.path(),
                task_params=task_params,
                custom_labels=custom_labels,
                required_task_outputs=required_task_outputs,
            )

    def _remove(self) -> None:
        self._target.remove()

    def _last_modification_time(self) -> datetime:
        return _get_last_modification_time(self._target.path)

    def _path(self) -> str:
        return self._target.path
class ModelTarget(TargetOnKart):
    """Target that stores a model plus its pickled load function as one zip archive.

    ``save_function(obj, path)`` writes the model into a temporary directory, the
    directory is zipped by the zip client; ``load_function(path)`` restores the model.
    The load function itself is pickled into the archive next to the model.
    """

    def __init__(
        self,
        file_path: str,
        temporary_directory: str,
        load_function: Any,
        save_function: Any,
        task_lock_params: TaskLockParams,
    ) -> None:
        # make_zip_client chooses the concrete client from the file path (e.g. local vs s3).
        self._zip_client = make_zip_client(file_path, temporary_directory)
        self._temporary_directory = temporary_directory
        self._save_function = save_function
        self._load_function = load_function
        self._task_lock_params = task_lock_params

    def _exists(self) -> bool:
        return self._zip_client.exists()

    def _get_task_lock_params(self) -> TaskLockParams:
        return self._task_lock_params

    def _load(self) -> Any:
        self._zip_client.unpack_archive()
        # Fall back to the load function pickled inside the archive when this
        # target was constructed without one.
        self._load_function = self._load_function or make_target(self._load_function_path()).load()
        model = self._load_function(self._model_path())
        self._remove_temporary_directory()
        return model

    def _dump(
        self,
        obj: Any,
        task_params: dict[str, str] | None = None,
        custom_labels: dict[str, str] | None = None,
        required_task_outputs: FlattenableItems[RequiredTaskOutput] | None = None,
    ) -> None:
        self._make_temporary_directory()
        self._save_function(obj, self._model_path())
        # The load function is dumped next to the model so `_load` can restore it later.
        make_target(self._load_function_path()).dump(
            self._load_function, task_params=task_params, custom_labels=custom_labels, required_task_outputs=required_task_outputs
        )
        self._zip_client.make_archive()
        self._remove_temporary_directory()

    def _remove(self) -> None:
        self._zip_client.remove()

    def _last_modification_time(self) -> datetime:
        return _get_last_modification_time(self._zip_client.path)

    def _path(self) -> str:
        return self._zip_client.path

    def _model_path(self):
        return os.path.join(self._temporary_directory, 'model.pkl')

    def _load_function_path(self):
        return os.path.join(self._temporary_directory, 'load_function.pkl')

    def _remove_temporary_directory(self):
        shutil.rmtree(self._temporary_directory)

    def _make_temporary_directory(self):
        os.makedirs(self._temporary_directory, exist_ok=True)
class LargeDataFrameProcessor:
    """Splits a large DataFrame into multiple pickle files of roughly `max_byte` bytes each."""

    def __init__(self, max_byte: int):
        self.max_byte = int(max_byte)

    def save(self, df: pd.DataFrame, file_path: str) -> None:
        """Save `df` as `<dir of file_path>/data_<i>.pkl` chunk files."""
        dir_path = os.path.dirname(file_path)
        os.makedirs(dir_path, exist_ok=True)
        if df.empty:
            # Keep one empty chunk so `load` still finds a file to read.
            df.to_pickle(os.path.join(dir_path, 'data_0.pkl'))
            return
        split_size = df.values.nbytes // self.max_byte + 1
        logger.info(f'saving a large pdDataFrame with split_size={split_size}')
        for i, idx in enumerate(np.array_split(range(df.shape[0]), split_size)):
            df.iloc[idx[0] : idx[-1] + 1].to_pickle(os.path.join(dir_path, f'data_{i}.pkl'))

    @staticmethod
    def load(file_path: str) -> pd.DataFrame:
        """Load and concatenate every `data_*.pkl` chunk in the directory of `file_path`.

        BUG FIX: glob() returns paths in an arbitrary, platform-dependent order, which
        made the row order of the concatenated frame nondeterministic. Sorting by
        (length, name) restores the numeric chunk order written by `save`
        (e.g. data_2 before data_10) deterministically, without risking a crash on
        unexpected file names.
        """
        dir_path = os.path.dirname(file_path)
        chunk_paths = sorted(glob(os.path.join(dir_path, 'data_*.pkl')), key=lambda p: (len(p), p))
        return pd.concat([pd.read_pickle(chunk_path) for chunk_path in chunk_paths])
def _make_file_system_target(file_path: str, processor: FileProcessor | None = None, store_index_in_feather: bool = True) -> luigi.target.FileSystemTarget:
    """Build the underlying luigi target (object storage or local) for `file_path`."""
    resolved_processor = processor or make_file_processor(file_path, store_index_in_feather=store_index_in_feather)
    if not ObjectStorage.if_object_storage_path(file_path):
        return luigi.LocalTarget(file_path, format=resolved_processor.format())
    return ObjectStorage.get_object_storage_target(file_path, resolved_processor.format())
def _make_file_path(original_path: str, unique_id: str | None = None) -> str:
if unique_id is not None:
[base, extension] = os.path.splitext(original_path)
return base + '_' + unique_id + extension
return original_path
def _get_last_modification_time(path: str) -> datetime:
    """Return the last-modified time of `path`, whether it lives locally or on object storage."""
    if not ObjectStorage.if_object_storage_path(path):
        return datetime.fromtimestamp(os.path.getmtime(path))
    if not ObjectStorage.exists(path):
        raise FileNotFoundError(f'No such file or directory: {path}')
    return ObjectStorage.get_timestamp(path)
def make_target(
    file_path: str,
    unique_id: str | None = None,
    processor: FileProcessor | None = None,
    task_lock_params: TaskLockParams | None = None,
    store_index_in_feather: bool = True,
) -> TargetOnKart:
    """Create a SingleFileTarget for `file_path`, optionally suffixed with `unique_id`."""
    # Lock params are derived from the original path plus the unique id, before the suffix is applied.
    lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
    path_with_id = _make_file_path(file_path, unique_id)
    resolved_processor = processor or make_file_processor(path_with_id, store_index_in_feather=store_index_in_feather)
    file_system_target = _make_file_system_target(path_with_id, processor=resolved_processor, store_index_in_feather=store_index_in_feather)
    return SingleFileTarget(target=file_system_target, processor=resolved_processor, task_lock_params=lock_params)
def make_model_target(
    file_path: str,
    temporary_directory: str,
    save_function: Any,
    load_function: Any,
    unique_id: str | None = None,
    task_lock_params: TaskLockParams | None = None,
) -> TargetOnKart:
    """Create a ModelTarget that archives save/load-function based outputs at `file_path`."""
    lock_params = task_lock_params if task_lock_params is not None else make_task_lock_params(file_path=file_path, unique_id=unique_id)
    path_with_id = _make_file_path(file_path, unique_id)
    # Hash the final path into the temp-dir name so different targets never share a working directory.
    work_directory = os.path.join(temporary_directory, hashlib.md5(path_with_id.encode()).hexdigest())
    return ModelTarget(
        file_path=path_with_id,
        temporary_directory=work_directory,
        save_function=save_function,
        load_function=load_function,
        task_lock_params=lock_params,
    )
================================================
FILE: gokart/task.py
================================================
from __future__ import annotations
import functools
import hashlib
import inspect
import os
import random
import types
from collections.abc import Callable, Generator, Iterable
from importlib import import_module
from logging import getLogger
from typing import Any, Generic, TypeVar, cast, overload
import luigi
import pandas as pd
from luigi.parameter import ParameterVisibility
import gokart
import gokart.target
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params, make_task_lock_params_for_run
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_run_with_lock
from gokart.file_processor import FileProcessor, make_file_processor
from gokart.pandas_type_config import PandasTypeConfigMap
from gokart.parameter import ExplicitBoolParameter, ListTaskInstanceParameter, TaskInstanceParameter
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart
from gokart.task_complete_check import task_complete_check_wrapper
from gokart.utils import FlattenableItems, flatten, get_dataframe_type_from_task, map_flattenable_items
logger = getLogger(__name__)
T = TypeVar('T')
K = TypeVar('K')
# NOTE: inherited from AssertionError for backward compatibility (Formerly, Gokart raises that exception when a task dumps an empty DataFrame).
class EmptyDumpError(AssertionError):
    """Raised when the task attempts to dump an empty DataFrame even though it is prohibited (``fail_on_empty_dump`` is set to True).

    Inherits from AssertionError for backward compatibility with the exception formerly raised in this situation.
    """
class TaskOnKart(luigi.Task, Generic[T]):
"""
This is a wrapper class of luigi.Task.
The key methods of a TaskOnKart are:
* :py:meth:`make_target` - this makes output target with a relative file path.
* :py:meth:`make_model_target` - this makes output target for models which generate multiple files to save.
* :py:meth:`load` - this loads input files of this task.
* :py:meth:`dump` - this save a object as output of this task.
"""
workspace_directory: luigi.Parameter[str] = luigi.Parameter(
default='./resources/', description='A directory to set outputs on. Please use a path starts with s3:// when you use s3.', significant=False
)
local_temporary_directory: luigi.Parameter[str] = luigi.Parameter(
default='./resources/tmp/', description='A directory to save temporary files.', significant=False
)
rerun: luigi.BoolParameter = luigi.BoolParameter(
default=False, description='If this is true, this task will run even if all output files exist.', significant=False
)
strict_check: luigi.BoolParameter = luigi.BoolParameter(
default=False, description='If this is true, this task will not run only if all input and output files exist.', significant=False
)
modification_time_check: luigi.BoolParameter = luigi.BoolParameter(
default=False,
description='If this is true, this task will not run only if all input and output files exist,'
' and all input files are modified before output file are modified.',
significant=False,
)
serialized_task_definition_check: luigi.BoolParameter = luigi.BoolParameter(
default=False,
description='If this is true, even if all outputs are present,this task will be executed if any changes have been made to the code.',
significant=False,
)
delete_unnecessary_output_files: luigi.BoolParameter = luigi.BoolParameter(
default=False, description='If this is true, delete unnecessary output files.', significant=False
)
significant: luigi.BoolParameter = luigi.BoolParameter(
default=True, description='If this is false, this task is not treated as a part of dependent tasks for the unique id.', significant=False
)
fix_random_seed_methods: luigi.Parameter[tuple[str, ...]] = luigi.ListParameter(
default=('random.seed', 'numpy.random.seed'), description='Fix random seed method list.', significant=False
)
FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER = -42497368
fix_random_seed_value: luigi.Parameter[int] = luigi.IntParameter(
default=FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER, description='Fix random seed method value.', significant=False
) # FIXME: should fix with OptionalIntParameter after newer luigi (https://github.com/spotify/luigi/pull/3079) will be released
redis_host: luigi.Parameter[str | None] = luigi.OptionalParameter(default=None, description='Task lock check is deactivated, when None.', significant=False)
redis_port: luigi.OptionalIntParameter = luigi.OptionalIntParameter(
default=None, # type: ignore
description='Task lock check is deactivated, when None.',
significant=False,
)
redis_timeout: luigi.IntParameter = luigi.IntParameter(
default=180, description='Redis lock will be released after `redis_timeout` seconds', significant=False
)
fail_on_empty_dump: luigi.Parameter[bool] = ExplicitBoolParameter(default=False, description='Fail when task dumps empty DF', significant=False)
store_index_in_feather: luigi.Parameter[bool] = ExplicitBoolParameter(
default=True, description='Wether to store index when using feather as a output object.', significant=False
)
cache_unique_id: luigi.Parameter[bool] = ExplicitBoolParameter(default=True, description='Cache unique id during runtime', significant=False)
should_dump_supplementary_log_files: luigi.Parameter[bool] = ExplicitBoolParameter(
default=True,
description='Whether to dump supplementary files (task_log, random_seed, task_params, processing_time, module_versions) or not. \
Note that when set to False, task_info functions (e.g. gokart.tree.task_info.make_task_info_as_tree_str()) cannot be used.',
significant=False,
)
complete_check_at_run: luigi.Parameter[bool] = ExplicitBoolParameter(
default=True, description='Check if output file exists at run. If exists, run() will be skipped.', significant=False
)
should_lock_run: luigi.Parameter[bool] = ExplicitBoolParameter(
default=False, significant=False, description='Whether to use redis lock or not at task run.'
)
@property
def priority(self):
    """Scheduling priority drawn uniformly from [0, 1).

    The process-global seed is fixed at task start, so a plain
    ``random.random()`` would repeat; a fresh ``random.Random()``
    instance keeps priorities unpredictable.
    """
    independent_rng = random.Random()
    return independent_rng.random()
def __init__(self, *args, **kwargs):
    """Apply `[TaskOnKart]` config-section defaults to kwargs, then wrap run() as configured."""
    self._add_configuration(kwargs, 'TaskOnKart')
    # 'This parameter is dumped into "workspace_directory/log/task_log/" when this task finishes with success.'
    self.task_log = dict()
    self.task_unique_id = None
    super().__init__(*args, **kwargs)
    # Snapshot rerun at construction; complete() consumes it exactly once.
    self._rerun_state = self.rerun
    self._lock_at_dump = True
    # Cache to_str_params to avoid slow task creation in a deep task tree.
    # For example, gokart.build(RecursiveTask(dep=RecursiveTask(dep=RecursiveTask(dep=HelloWorldTask())))) results in O(n^2) calls to to_str_params.
    # However, @lru_cache cannot be used as a decorator because luigi.Task employs metaclass tricks.
    self.to_str_params = functools.lru_cache(maxsize=None)(self.to_str_params)  # type: ignore[method-assign]
    if self.complete_check_at_run:
        # Re-check completeness just before run() so output produced meanwhile is not recomputed.
        self.run = task_complete_check_wrapper(run_func=self.run, complete_check_func=self.complete)  # type: ignore
    if self.should_lock_run:
        # The run-level lock replaces the dump-time lock.
        self._lock_at_dump = False
        assert self.redis_host is not None, 'redis_host must be set when should_lock_run is True.'
        assert self.redis_port is not None, 'redis_port must be set when should_lock_run is True.'
        task_lock_params = make_task_lock_params_for_run(task_self=self)
        self.run = wrap_run_with_lock(run_func=self.run, task_lock_params=task_lock_params)  # type: ignore
def input(self) -> FlattenableItems[TargetOnKart]:
    """Return the outputs of required tasks, narrowed to gokart's TargetOnKart type."""
    upstream_outputs = super().input()
    return cast(FlattenableItems[TargetOnKart], upstream_outputs)
def output(self) -> FlattenableItems[TargetOnKart]:
    """Default output: a single pickle target derived from module path and class name."""
    default_target = self.make_target()
    return default_target
def requires(self) -> FlattenableItems[TaskOnKart[Any]]:
    """Default requirements: every TaskOnKart attribute on this instance, keyed by attribute name."""
    dependency_tasks = self.make_task_instance_dictionary()
    if not dependency_tasks:
        # An empty dict is normalized to an empty list.
        return []
    return dependency_tasks
def make_task_instance_dictionary(self) -> dict[str, TaskOnKart[Any]]:
    """Collect instance attributes that are TaskOnKart instances (or tuples of them)."""
    return {attr_name: attr_value for attr_name, attr_value in vars(self).items() if self.is_task_on_kart(attr_value)}
@staticmethod
def is_task_on_kart(value):
    """True for a TaskOnKart instance or a non-empty tuple made only of TaskOnKart instances."""
    if isinstance(value, TaskOnKart):
        return True
    if not isinstance(value, tuple) or not value:
        return False
    return all(isinstance(element, TaskOnKart) for element in value)
@classmethod
def _add_configuration(cls, kwargs, section):
    """Fill ``kwargs`` from the ``[section]`` config block for parameters not passed explicitly.

    Subclass attributes override TaskOnKart's when both declare the same parameter.
    Config values are parsed with the parameter object's own ``parse``.
    """
    config = luigi.configuration.get_config()
    class_variables = dict(TaskOnKart.__dict__)
    class_variables.update(dict(cls.__dict__))
    if section not in config:
        return
    for key, value in dict(config[section]).items():
        # Explicit keyword arguments always win over configuration values.
        if key not in kwargs and key in class_variables:
            kwargs[key] = class_variables[key].parse(value)
def complete(self) -> bool:
    """Decide whether this task may be skipped.

    - First call after construction with rerun=True: remove all outputs, reset
      the flag, and report incomplete (forces exactly one rerun).
    - Otherwise complete iff every output exists; with strict_check or
      modification_time_check, all requirements must also be complete and all
      inputs present.
    - With modification_time_check (and inputs present), additionally require
      that no input is newer than the outputs.
    """
    if self._rerun_state:
        for target in flatten(self.output()):
            target.remove()
        self._rerun_state = False
        return False

    is_completed = all([t.exists() for t in flatten(self.output())])

    if self.strict_check or self.modification_time_check:
        requirements = flatten(self.requires())
        inputs = flatten(self.input())
        is_completed = is_completed and all([task.complete() for task in requirements]) and all([i.exists() for i in inputs])

    if not self.modification_time_check or not is_completed or not self.input():
        return is_completed

    return self._check_modification_time()
def _check_modification_time(self) -> bool:
    """Return True when the newest input is not newer than the oldest output.

    Paths used as both input and output (in-place tasks) are excluded from the
    comparison; when either side becomes empty the check passes trivially.
    """
    common_path = set(t.path() for t in flatten(self.input())) & set(t.path() for t in flatten(self.output()))
    input_tasks = [t for t in flatten(self.input()) if t.path() not in common_path]
    output_tasks = [t for t in flatten(self.output()) if t.path() not in common_path]

    input_modification_time = max([target.last_modification_time() for target in input_tasks]) if input_tasks else None
    output_modification_time = min([target.last_modification_time() for target in output_tasks]) if output_tasks else None

    if input_modification_time is None or output_modification_time is None:
        return True

    # "=" must be required in the following statements, because some tasks use input targets as output targets.
    return input_modification_time <= output_modification_time
def clone(self, cls=None, **kwargs):
    """Create a new task of class ``cls`` (default: this class), copying parameter values.

    ``kwargs`` override copied values; ``rerun``, ``strict_check`` and
    ``modification_time_check`` are deliberately never inherited from ``self``.
    """
    _SPECIAL_PARAMS = {'rerun', 'strict_check', 'modification_time_check'}
    if cls is None:
        cls = self.__class__

    param_overrides = {}
    for param_name, _ in cls.get_params():
        if param_name in kwargs:
            param_overrides[param_name] = kwargs[param_name]
            continue
        if param_name in _SPECIAL_PARAMS:
            continue
        if hasattr(self, param_name):
            param_overrides[param_name] = getattr(self, param_name)

    return cls(**param_overrides)
def make_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, processor: FileProcessor | None = None) -> TargetOnKart:
    """Create a file target under ``workspace_directory``.

    :param relative_file_path: Path relative to the workspace; defaults to
        ``<module path>/<ClassName>.pkl``.
    :param use_unique_id: Append this task's unique id to the file base name.
    :param processor: Explicit file processor; when None and a path was given,
        one is derived from the task's DataFrame-type parameter.
    """
    formatted_relative_file_path = (
        relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.pkl')
    )
    file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
    unique_id = self.make_unique_id() if use_unique_id else None

    # Auto-select processor based on type parameter if not provided
    if processor is None and relative_file_path is not None:
        processor = self._create_processor_for_dataframe_type(file_path)

    task_lock_params = make_task_lock_params(
        file_path=file_path,
        unique_id=unique_id,
        redis_host=self.redis_host,
        redis_port=self.redis_port,
        redis_timeout=self.redis_timeout,
        raise_task_lock_exception_on_collision=False,
    )
    return gokart.target.make_target(
        file_path=file_path, unique_id=unique_id, processor=processor, task_lock_params=task_lock_params, store_index_in_feather=self.store_index_in_feather
    )
def _create_processor_for_dataframe_type(self, file_path: str) -> FileProcessor:
    """Build the file processor matching this task's declared DataFrame type."""
    dataframe_type = get_dataframe_type_from_task(self)
    return make_file_processor(file_path, dataframe_type=dataframe_type, store_index_in_feather=self.store_index_in_feather)
def make_large_data_frame_target(self, relative_file_path: str | None = None, use_unique_id: bool = True, max_byte: int = int(2**26)) -> TargetOnKart:
    """Create a zip-backed target that stores a large DataFrame in chunks of at most ``max_byte``.

    :param relative_file_path: Path relative to the workspace; defaults to
        ``<module path>/<ClassName>.zip``.
    :param use_unique_id: Append this task's unique id to the file base name.
    :param max_byte: Maximum chunk size in bytes (default 2**26 = 64 MiB).
    """
    formatted_relative_file_path = (
        relative_file_path if relative_file_path is not None else os.path.join(self.__module__.replace('.', '/'), f'{type(self).__name__}.zip')
    )
    file_path = os.path.join(self.workspace_directory, formatted_relative_file_path)
    unique_id = self.make_unique_id() if use_unique_id else None
    task_lock_params = make_task_lock_params(
        file_path=file_path,
        unique_id=unique_id,
        redis_host=self.redis_host,
        redis_port=self.redis_port,
        redis_timeout=self.redis_timeout,
        raise_task_lock_exception_on_collision=False,
    )
    return gokart.target.make_model_target(
        file_path=file_path,
        temporary_directory=self.local_temporary_directory,
        unique_id=unique_id,
        save_function=gokart.target.LargeDataFrameProcessor(max_byte=max_byte).save,
        load_function=gokart.target.LargeDataFrameProcessor.load,
        task_lock_params=task_lock_params,
    )
def make_model_target(
    self, relative_file_path: str, save_function: Callable[[Any, str], None], load_function: Callable[[str], Any], use_unique_id: bool = True
) -> TargetOnKart:
    """
    Make target for models which generate multiple files in saving, e.g. gensim.Word2Vec, Tensorflow, and so on.

    :param relative_file_path: A file path to save. Must end with 'zip' (the
        artifacts are bundled into a zip archive via a temporary directory).
    :param save_function: A function to save a model. This takes a model object and a file path.
    :param load_function: A function to load a model. This takes a file path and returns a model object.
    :param use_unique_id: If this is true, add an unique id to a file base name.
    """
    file_path = os.path.join(self.workspace_directory, relative_file_path)
    assert relative_file_path[-3:] == 'zip', f'extension must be zip, but {relative_file_path} is passed.'
    unique_id = self.make_unique_id() if use_unique_id else None
    task_lock_params = make_task_lock_params(
        file_path=file_path,
        unique_id=unique_id,
        redis_host=self.redis_host,
        redis_port=self.redis_port,
        redis_timeout=self.redis_timeout,
        raise_task_lock_exception_on_collision=False,
    )
    return gokart.target.make_model_target(
        file_path=file_path,
        temporary_directory=self.local_temporary_directory,
        unique_id=unique_id,
        save_function=save_function,
        load_function=load_function,
        task_lock_params=task_lock_params,
    )
@overload
def load(self, target: None | str | TargetOnKart = None) -> Any: ...

@overload
def load(self, target: TaskOnKart[K]) -> K: ...

@overload
def load(self, target: list[TaskOnKart[K]]) -> list[K]: ...

def load(self, target: None | str | TargetOnKart | TaskOnKart[K] | list[TaskOnKart[K]] = None) -> Any:
    """Load input data, mirroring the list/tuple/dict nesting of the resolved targets.

    ``target`` may be None (all inputs), a key into a dict-shaped input, a
    concrete target, a required task, or a list of required tasks; see
    ``_get_input_targets`` for resolution rules.
    """

    def _load(targets):
        # Recurse through containers; leaves are targets with a .load() method.
        if isinstance(targets, list) or isinstance(targets, tuple):
            return [_load(t) for t in targets]
        if isinstance(targets, dict):
            return {k: _load(t) for k, t in targets.items()}
        return targets.load()

    return _load(self._get_input_targets(target))
@overload
def load_generator(self, target: None | str | TargetOnKart = None) -> Generator[Any, None, None]: ...

@overload
def load_generator(self, target: list[TaskOnKart[K]]) -> Generator[K, None, None]: ...

def load_generator(self, target: None | str | TargetOnKart | list[TaskOnKart[K]] = None) -> Generator[Any, None, None]:
    """Lazily yield loaded values from the resolved input targets."""

    def _load(targets):
        if isinstance(targets, list) or isinstance(targets, tuple):
            for t in targets:
                yield from _load(t)
        elif isinstance(targets, dict):
            for k, t in targets.items():
                # NOTE(review): iterating a dict yields its *keys*, so this yields the
                # key ``k`` rather than the loaded value and never consumes ``_load(t)``
                # — confirm whether this is intended for dict-shaped inputs.
                yield from {k: _load(t)}
        else:
            yield targets.load()

    return cast(Generator[Any, None, None], _load(self._get_input_targets(target)))
@overload
def dump(self, obj: T, target: None = None, custom_labels: dict[Any, Any] | None = None) -> None: ...

@overload
def dump(self, obj: Any, target: str | TargetOnKart, custom_labels: dict[Any, Any] | None = None) -> None: ...

def dump(self, obj: Any, target: None | str | TargetOnKart = None, custom_labels: dict[str, Any] | None = None) -> None:
    """Persist ``obj`` to the resolved output target.

    Checks registered pandas type constraints first, optionally rejects empty
    DataFrames (``fail_on_empty_dump``), and passes provenance along: significant
    public parameters plus the output paths of every required task.
    """
    PandasTypeConfigMap().check(obj, task_namespace=self.task_namespace)
    if self.fail_on_empty_dump:
        if isinstance(obj, pd.DataFrame) and obj.empty:
            raise EmptyDumpError()

    required_task_outputs = map_flattenable_items(
        lambda task: map_flattenable_items(lambda output: RequiredTaskOutput(task_name=task.get_task_family(), output_path=output.path()), task.output()),
        self.requires(),
    )

    self._get_output_target(target).dump(
        obj,
        lock_at_dump=self._lock_at_dump,
        # super() bypasses the lru_cache-wrapped instance attribute set in __init__.
        task_params=super().to_str_params(only_significant=True, only_public=True),
        custom_labels=custom_labels,
        required_task_outputs=required_task_outputs,
    )
@staticmethod
def get_code(target_class: Any) -> set[str]:
def has_sourcecode(obj):
return inspect.ismethod(obj) or inspect.isfunction(obj) or inspect.isframe(obj) or inspect.iscode(obj)
return {inspect.getsource(t) for _, t in inspect.getmembers(target_class, has_sourcecode)}
def get_own_code(self):
    """Source of methods this task adds on top of TaskOnKart, concatenated in sorted order."""
    base_codes = self.get_code(TaskOnKart)
    subclass_codes = self.get_code(self)
    added_codes = subclass_codes - base_codes
    return ''.join(sorted(added_codes))
def make_unique_id(self) -> str:
    """Return this task's unique id, memoizing it on the instance when cache_unique_id is on."""
    unique_id = self.task_unique_id if self.task_unique_id else self._make_hash_id()
    if self.cache_unique_id:
        self.task_unique_id = unique_id
    return unique_id
def _make_hash_id(self) -> str:
    """Compute the md5 unique id over significant params, dependency ids, and the class name.

    Each required TaskOnKart contributes its own unique id (insignificant tasks
    are skipped); non-gokart luigi tasks contribute their significant string
    params. With serialized_task_definition_check, the task's own source code
    is hashed in as well, so editing the code changes the id.
    """

    def _to_str_params(task):
        if isinstance(task, TaskOnKart):
            return str(task.make_unique_id()) if task.significant else None
        if not isinstance(task, luigi.Task):
            raise ValueError(f'Task.requires method returns {type(task)}. You should return luigi.Task.')
        return task.to_str_params(only_significant=True)

    dependencies = [_to_str_params(task) for task in flatten(self.requires())]
    dependencies = [d for d in dependencies if d is not None]
    dependencies.append(self.to_str_params(only_significant=True))
    dependencies.append(self.__class__.__name__)
    if self.serialized_task_definition_check:
        dependencies.append(self.get_own_code())
    return hashlib.md5(str(dependencies).encode()).hexdigest()
def _get_input_targets(self, target: None | str | TargetOnKart | TaskOnKart[Any] | list[TaskOnKart[Any]]) -> FlattenableItems[TargetOnKart]:
    """Resolve ``target`` into input target(s); see load() for the accepted forms.

    Resolution order matters: the str check must precede the Iterable check
    (str is Iterable), and lists of tasks are handled before single tasks.
    """
    if target is None:
        return self.input()
    if isinstance(target, str):
        input = self.input()
        assert isinstance(input, dict), f'input must be dict[str, TargetOnKart], but {type(input)} is passed.'
        result: FlattenableItems[TargetOnKart] = input[target]
        return result
    if isinstance(target, Iterable):
        return [self._get_input_targets(t) for t in target]
    if isinstance(target, TaskOnKart):
        # Guard: only tasks that are actually required may be loaded from.
        requires_unique_ids = [task.make_unique_id() for task in flatten(self.requires())]
        assert target.make_unique_id() in requires_unique_ids, f'{target} should be in requires method'
        return target.output()
    # Already a concrete target.
    return target
def _get_output_target(self, target: None | str | TargetOnKart) -> TargetOnKart:
    """Resolve ``target`` (None, dict key, or concrete target) for dumping."""
    if isinstance(target, str):
        outputs = self.output()
        assert isinstance(outputs, dict), f'output must be dict[str, TargetOnKart], but {type(outputs)} is passed.'
        named_output = outputs[target]
        assert isinstance(named_output, TargetOnKart), f'output must be dict[str, TargetOnKart], but {type(outputs)} is passed.'
        return named_output
    if target is None:
        outputs = self.output()
        assert isinstance(outputs, TargetOnKart), f'output must be TargetOnKart, but {type(outputs)} is passed.'
        return outputs
    return target
def get_info(self, only_significant=False):
    """Serialize parameter values to strings for display.

    Task-instance parameters are rendered as '<ClassName>-<unique id>'; all
    other parameters use their own serialize(). When only_significant is True,
    insignificant parameters are omitted.
    """
    params_str = {}
    params = dict(self.get_params())
    for param_name, param_value in self.param_kwargs.items():
        if (not only_significant) or params[param_name].significant:
            if isinstance(params[param_name], gokart.TaskInstanceParameter):
                params_str[param_name] = type(param_value).__name__ + '-' + param_value.make_unique_id()
            else:
                params_str[param_name] = params[param_name].serialize(param_value)
    return params_str
def _get_task_log_target(self):
    """Target under log/task_log/ that stores this task's task_log dict."""
    log_relative_path = f'log/task_log/{type(self).__name__}.pkl'
    return self.make_target(log_relative_path)
def get_task_log(self) -> dict[str, Any]:
    """Return the in-memory task_log, falling back to the persisted one, else {}."""
    # The target is created first to preserve the original call order
    # (make_target may cache the unique id on the instance as a side effect).
    persisted = self._get_task_log_target()
    if self.task_log:
        return self.task_log
    if persisted.exists():
        return cast(dict[Any, Any], self.load(persisted))
    return dict()
@luigi.Task.event_handler(luigi.Event.SUCCESS)
def _dump_task_log(self):
    """On SUCCESS, record the output file paths into task_log and persist it (when enabled)."""
    self.task_log['file_path'] = [target.path() for target in flatten(self.output())]
    if self.should_dump_supplementary_log_files:
        self.dump(self.task_log, self._get_task_log_target())
def _get_task_params_target(self):
    """Target under log/task_params/ that stores this task's significant string params."""
    params_relative_path = f'log/task_params/{type(self).__name__}.pkl'
    return self.make_target(params_relative_path)
def get_task_params(self) -> dict[str, Any]:
    """Load the string parameters persisted by ``_dump_task_params``, or {} if absent.

    Bug fix: this previously read ``self._get_task_log_target()`` — the *task log*
    target — so it returned the task's log dict (or {}) instead of the parameters
    that ``_dump_task_params`` writes to the task-params target.
    """
    target = self._get_task_params_target()
    if target.exists():
        return cast(dict[Any, Any], self.load(target))
    return dict()
@luigi.Task.event_handler(luigi.Event.START)
def _set_random_seed(self):
    """On START, seed the configured RNG methods and persist which seed/methods applied."""
    if self.should_dump_supplementary_log_files:
        random_seed = self._get_random_seed()
        seed_methods = self.try_set_seed(list(self.fix_random_seed_methods), random_seed)
        self.dump({'seed': random_seed, 'seed_methods': seed_methods}, self._get_random_seeds_target())
def _get_random_seeds_target(self):
    """Target under log/random_seed/ recording the seed applied to this task."""
    seed_relative_path = f'log/random_seed/{type(self).__name__}.pkl'
    return self.make_target(seed_relative_path)
@staticmethod
def try_set_seed(methods: list[str], random_seed: int) -> list[str]:
success_methods: list[str] = []
for method_name in methods:
try:
parts = method_name.split('.')
m: Any = import_module(parts[0])
for x in parts[1:]:
m = getattr(m, x)
m(random_seed)
success_methods.append(method_name)
except ModuleNotFoundError:
pass
except AttributeError:
pass
return success_methods
def _get_random_seed(self):
    """Seed to apply: the explicit fix_random_seed_value when set, else one derived from the unique id."""
    configured_seed = self.fix_random_seed_value
    if configured_seed and configured_seed != self.FIX_RANDOM_SEED_VALUE_NONE_MAGIC_NUMBER:
        return configured_seed
    # maximum value accepted by numpy.random.seed
    return int(self.make_unique_id(), 16) % (2**32 - 1)
@luigi.Task.event_handler(luigi.Event.START)
def _dump_task_params(self):
    """On START, persist significant string parameters (used later by restore())."""
    if self.should_dump_supplementary_log_files:
        self.dump(self.to_str_params(only_significant=True), self._get_task_params_target())
def _get_processing_time_target(self):
    """Target under log/processing_time/ recording how long this task ran."""
    timing_relative_path = f'log/processing_time/{type(self).__name__}.pkl'
    return self.make_target(timing_relative_path)
def get_processing_time(self) -> str:
    """Return the persisted processing time, or 'unknown' when none was recorded."""
    timing_target = self._get_processing_time_target()
    if not timing_target.exists():
        return 'unknown'
    return cast(str, self.load(timing_target))
@luigi.Task.event_handler(luigi.Event.PROCESSING_TIME)
def _dump_processing_time(self, processing_time):
    """On PROCESSING_TIME, persist the measured duration (when enabled)."""
    if self.should_dump_supplementary_log_files:
        self.dump(processing_time, self._get_processing_time_target())
@classmethod
def restore(cls, unique_id):
    """Re-create a task instance from the parameters dumped by _dump_task_params.

    Loads 'log/task_params/<ClassName>_<unique_id>.pkl' via a bare TaskOnKart
    (use_unique_id=False because the id is already embedded in the file name).
    """
    params = TaskOnKart().make_target(f'log/task_params/{cls.__name__}_{unique_id}.pkl', use_unique_id=False).load()
    return cls.from_str_params(params)
@luigi.Task.event_handler(luigi.Event.FAILURE)
def _log_unique_id(self, exception):
    """On FAILURE, log the task name and unique id so the failed run can be located."""
    task_name = type(self).__name__
    unique_id = self.make_unique_id()
    logger.info(f'FAILURE:\n task name={task_name}\n unique id={unique_id}')
@luigi.Task.event_handler(luigi.Event.START)
def _dump_module_versions(self):
    """On START, persist the versions of imported modules (when enabled)."""
    if self.should_dump_supplementary_log_files:
        self.dump(self._get_module_versions(), self._get_module_versions_target())
def _get_module_versions_target(self):
    """Target under log/module_versions/ recording imported module versions."""
    versions_relative_path = f'log/module_versions/{type(self).__name__}.txt'
    return self.make_target(versions_relative_path)
def _get_module_versions(self) -> str:
module_versions = []
for x in set([x.split('.')[0] for x in globals().keys() if isinstance(x, types.ModuleType) and '_' not in x]):
module = import_module(x)
if '__version__' in dir(module):
if isinstance(module.__version__, str):
version = module.__version__.split(' ')[0]
else:
version = '.'.join([str(v) for v in module.__version__])
module_versions.append(f'{x}=={version}')
return '\n'.join(module_versions)
def __repr__(self):
    """
    Build a task representation like
    `MyTask[aca2f28555dadd0f1e3dee3d4b973651](param1=1.5, param2='5', data_task=DataTask(c1f5d06aa580c5761c55bd83b18b0b4e))`

    Unlike __str__, this includes non-public parameters.
    """
    return self._get_task_string(only_public=False)
def __str__(self):
    """
    Build a human-readable task representation like
    `MyTask[aca2f28555dadd0f1e3dee3d4b973651](param1=1.5, param2='5', data_task=DataTask(c1f5d06aa580c5761c55bd83b18b0b4e))`

    Only public parameters are included.
    """
    public_only_repr = self._get_task_string(only_public=True)
    return public_only_repr
def _get_task_string(self, only_public=False):
    """
    Convert a task representation like `MyTask(param1=1.5, param2='5', data_task=DataTask(id=35tyi))`

    Insignificant parameters are always omitted; with only_public=True,
    parameters whose visibility is not PUBLIC are omitted as well.
    """
    params = self.get_params()
    param_values = self.get_param_values(params, [], self.param_kwargs)

    # Build up task id
    repr_parts = []
    param_objs = dict(params)
    for param_name, param_value in param_values:
        param_obj = param_objs[param_name]
        if param_obj.significant and ((not only_public) or param_obj.visibility == ParameterVisibility.PUBLIC):
            repr_parts.append(f'{param_name}={self._make_representation(param_obj, param_value)}')

    task_str = f'{self.get_task_family()}[{self.make_unique_id()}]({", ".join(repr_parts)})'
    return task_str
def _make_representation(self, param_obj: luigi.Parameter, param_value: Any) -> str:
    """Render one parameter value for __repr__/__str__; task params show 'Family(unique_id)'."""
    if isinstance(param_obj, TaskInstanceParameter):
        return f'{param_value.get_task_family()}({param_value.make_unique_id()})'
    if isinstance(param_obj, ListTaskInstanceParameter):
        rendered_tasks = ', '.join(f'{v.get_task_family()}({v.make_unique_id()})' for v in param_value)
        return f'[{rendered_tasks}]'
    return str(param_obj.serialize(param_value))
================================================
FILE: gokart/task_complete_check.py
================================================
from __future__ import annotations
import functools
from collections.abc import Callable
from logging import getLogger
from typing import Any
logger = getLogger(__name__)


def task_complete_check_wrapper(run_func: Callable[..., Any], complete_check_func: Callable[..., Any]) -> Callable[..., Any]:
    """Wrap ``run_func`` so it becomes a no-op when ``complete_check_func()`` reports done.

    Lets a worker skip run() when the task's output already exists (e.g. it was
    produced by another worker in the meantime). The wrapper preserves
    ``run_func``'s metadata via functools.wraps.
    """

    @functools.wraps(run_func)
    def wrapper(*args, **kwargs):
        already_complete = complete_check_func()
        if not already_complete:
            return run_func(*args, **kwargs)
        logger.warning(f'{run_func.__name__} is skipped because the task is already completed.')
        return None

    return wrapper
================================================
FILE: gokart/testing/__init__.py
================================================
__all__ = [
'test_run',
'try_to_run_test_for_empty_data_frame',
'assert_frame_contents_equal',
]
from gokart.testing.check_if_run_with_empty_data_frame import test_run, try_to_run_test_for_empty_data_frame
from gokart.testing.pandas_assert import assert_frame_contents_equal
================================================
FILE: gokart/testing/check_if_run_with_empty_data_frame.py
================================================
from __future__ import annotations
import logging
import sys
from typing import Any
import luigi
from luigi.cmdline_parser import CmdlineParser
import gokart
from gokart.utils import flatten
test_logger = logging.getLogger(__name__)
test_logger.addHandler(logging.StreamHandler())
test_logger.setLevel(logging.INFO)
class test_run(gokart.TaskOnKart[Any]):
    """CLI flag holder for the empty-DataFrame test mode (see try_to_run_test_for_empty_data_frame).

    ``--test-run-pandas`` enables the test; ``--test-run-namespace`` restricts it
    to tasks in one namespace.
    """

    pandas: luigi.BoolParameter = luigi.BoolParameter()
    namespace: luigi.OptionalStrParameter = luigi.OptionalStrParameter(
        default=None, description='When task namespace is not defined explicitly, please use "__not_user_specified".'
    )
class _TestStatus:
def __init__(self, task: gokart.TaskOnKart[Any]) -> None:
self.namespace = task.task_namespace
self.name = type(task).__name__
self.task_id = task.make_unique_id()
self.status = 'OK'
self.message: Exception | None = None
def format(self) -> str:
s = f'status={self.status}; namespace={self.namespace}; name={self.name}; id={self.task_id};'
if self.message:
s += f' message={type(self.message)}: {", ".join(map(str, self.message.args))}'
return s
def fail(self) -> bool:
return self.status != 'OK'
def _get_all_tasks(task: gokart.TaskOnKart[Any]) -> list[gokart.TaskOnKart[Any]]:
    """Collect ``task`` and all of its transitive requirements, parent before children."""
    collected = [task]
    for required_task in flatten(task.requires()):
        collected += _get_all_tasks(required_task)
    return collected
def _run_with_test_status(task: gokart.TaskOnKart[Any]) -> _TestStatus:
    """Run ``task`` once, capturing any exception into the returned _TestStatus."""
    status_record = _TestStatus(task)
    try:
        task.run()
    except Exception as e:
        # Record the failure instead of propagating; the caller aggregates results.
        status_record.status = 'NG'
        status_record.message = e
    return status_record
def _test_run_with_empty_data_frame(cmdline_args: list[str], test_run_params: test_run) -> None:
    """Run the workflow normally, then re-run every task's run() with dump() patched to a no-op.

    gokart.run is expected to exit with status 0 first; afterwards each task
    (optionally filtered by namespace) is executed again and failures are
    collected. Exits the process with status 1 when any task raised.
    """
    from unittest.mock import patch

    try:
        gokart.run(cmdline_args=cmdline_args)
    except SystemExit as e:
        assert e.code == 0, f'original workflow does not run properly. It exited with error code {e}.'

    with CmdlineParser.global_instance(cmdline_args) as cp:
        all_tasks = _get_all_tasks(cp.get_task_obj())

    if test_run_params.namespace is not None:
        all_tasks = [t for t in all_tasks if t.task_namespace == test_run_params.namespace]

    # Patch dump globally so the re-runs never overwrite real outputs.
    with patch('gokart.TaskOnKart.dump', new=lambda *args, **kwargs: None):
        test_status_list = [_run_with_test_status(t) for t in all_tasks]

    test_logger.info('gokart test results:\n' + '\n'.join(s.format() for s in test_status_list))
    if any(s.fail() for s in test_status_list):
        sys.exit(1)
def try_to_run_test_for_empty_data_frame(cmdline_args: list[str]) -> None:
    """Entry point: when '--test-run-pandas' is set, run the empty-DataFrame test and exit.

    The '--test-run-*' flags are stripped before delegating; when the flag is set
    the process always exits (0 on success, 1 from the helper on failure). When
    the flag is absent this function returns without side effects.
    """
    with CmdlineParser.global_instance(cmdline_args):
        test_run_params = test_run()

    if test_run_params.pandas:
        cmdline_args = [a for a in cmdline_args if not a.startswith('--test-run-')]
        _test_run_with_empty_data_frame(cmdline_args=cmdline_args, test_run_params=test_run_params)
        sys.exit(0)
================================================
FILE: gokart/testing/pandas_assert.py
================================================
from __future__ import annotations
from typing import Any
import pandas as pd
def assert_frame_contents_equal(actual: pd.DataFrame, expected: pd.DataFrame, **kwargs: Any) -> None:
    """
    Assert that two DataFrames hold the same contents.

    Behaves like pandas.testing.assert_frame_equal, except that:

    - the order of index and columns is ignored;
    - duplicated index or column labels cause a failure.

    Parameters
    ----------
    - actual, expected: pd.DataFrame
        DataFrames to be compared.
    - kwargs: Any
        Parameters passed to pandas.testing.assert_frame_equal.
    """
    assert isinstance(actual, pd.DataFrame), 'actual is not a DataFrame'
    assert isinstance(expected, pd.DataFrame), 'expected is not a DataFrame'

    assert actual.index.is_unique, 'actual index is not unique'
    assert expected.index.is_unique, 'expected index is not unique'
    assert actual.columns.is_unique, 'actual columns is not unique'
    assert expected.columns.is_unique, 'expected columns is not unique'

    assert set(actual.columns) == set(expected.columns), 'columns are not equal'
    assert set(actual.index) == set(expected.index), 'indexes are not equal'

    # Align expected onto actual's row/column order before the strict comparison.
    aligned_expected = expected.reindex(actual.index)[actual.columns]
    pd.testing.assert_frame_equal(actual, aligned_expected, **kwargs)
================================================
FILE: gokart/tree/task_info.py
================================================
from __future__ import annotations
import os
from typing import Any
import pandas as pd
from gokart.target import make_target
from gokart.task import TaskOnKart
from gokart.tree.task_info_formatter import make_task_info_tree, make_tree_info, make_tree_info_table_list
def make_task_info_as_tree_str(task: TaskOnKart[Any], details: bool = False, abbr: bool = True, ignore_task_names: list[str] | None = None) -> str:
    """
    Return a string representation of the tasks, their statuses/parameters in a dependency tree format

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - details: bool
        Whether or not to output details.
    - abbr: bool
        Whether or not to simplify tasks information that has already appeared.
    - ignore_task_names: list[str] | None
        List of task names to ignore.

    Returns
    -------
    - tree_info : str
        Formatted task dependency tree.
    """
    info_tree = make_task_info_tree(task, ignore_task_names=ignore_task_names)
    return make_tree_info(task_info=info_tree, indent='', last=True, details=details, abbr=abbr, visited_tasks=set())
def make_task_info_as_table(task: TaskOnKart[Any], ignore_task_names: list[str] | None = None) -> pd.DataFrame:
    """Return a table containing information about dependent tasks.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - ignore_task_names: list[str] | None
        List of task names to ignore.

    Returns
    -------
    - task_info_table : pandas.DataFrame
        Formatted task dependency table, one row per unique task.
    """
    info_tree = make_task_info_tree(task, ignore_task_names=ignore_task_names)
    table_rows = make_tree_info_table_list(task_info=info_tree, visited_tasks=set())
    return pd.DataFrame(table_rows)
def dump_task_info_table(task: TaskOnKart[Any], task_info_dump_path: str, ignore_task_names: list[str] | None = None) -> None:
    """Dump a table containing information about dependent tasks.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - task_info_dump_path: str
        Output target file path. Path destination can be `local`, `S3`, or `GCS`.
        File extension can be any type that gokart file processor accepts, including `csv`, `pickle`, or `txt`.
        See `TaskOnKart.make_target module ` for details.
    - ignore_task_names: list[str] | None
        List of task names to ignore.

    Returns
    -------
    None

    The root task's unique id is always appended to the dump file name.
    """
    task_info_table = make_task_info_as_table(task=task, ignore_task_names=ignore_task_names)
    unique_id = task.make_unique_id()
    task_info_target = make_target(file_path=task_info_dump_path, unique_id=unique_id)
    task_info_target.dump(obj=task_info_table, lock_at_dump=False)
def dump_task_info_tree(task: TaskOnKart[Any], task_info_dump_path: str, ignore_task_names: list[str] | None = None, use_unique_id: bool = True) -> None:
    """Dump the task info tree object (TaskInfo) to a pickle file.

    Parameters
    ----------
    - task: TaskOnKart
        Root task.
    - task_info_dump_path: str
        Output target file path. Path destination can be `local`, `S3`, or `GCS`.
        File extension must be '.pkl'.
    - ignore_task_names: list[str] | None
        List of task names to ignore.
    - use_unique_id: bool = True
        Whether to use unique id to dump target file. Default is True.

    Returns
    -------
    None
    """
    extension = os.path.splitext(task_info_dump_path)[1]
    # Fixed typo in the error message ('extention' -> 'extension').
    assert extension == '.pkl', f'File extension must be `.pkl`, not `{extension}`.'

    task_info_tree = make_task_info_tree(task, ignore_task_names=ignore_task_names)
    unique_id = task.make_unique_id() if use_unique_id else None
    task_info_target = make_target(file_path=task_info_dump_path, unique_id=unique_id)
    task_info_target.dump(obj=task_info_tree, lock_at_dump=False)
================================================
FILE: gokart/tree/task_info_formatter.py
================================================
from __future__ import annotations
import typing
import warnings
from dataclasses import dataclass
from typing import Any, NamedTuple
from gokart.task import TaskOnKart
from gokart.utils import FlattenableItems, flatten
@dataclass
class TaskInfo:
    """Snapshot of one task's state, consumed by the tree and table formatters."""

    name: str
    unique_id: str
    output_paths: list[str]
    params: dict[str, Any]
    processing_time: str
    is_complete: str
    task_log: dict[str, Any]
    requires: FlattenableItems[RequiredTask]
    children_task_infos: list[TaskInfo]

    def get_task_id(self):
        """'name_uniqueid' key used to de-duplicate repeated tasks."""
        return f'{self.name}_{self.unique_id}'

    def get_task_title(self):
        """'(STATUS) Name[unique_id]' heading shown in the tree view."""
        return f'({self.is_complete}) {self.name}[{self.unique_id}]'

    def get_task_detail(self):
        """Detail suffix appended to the title when details are requested."""
        return f'(parameter={self.params}, output={self.output_paths}, time={self.processing_time}, task_log={self.task_log})'

    def task_info_dict(self):
        """Flat dict view for the table output; children are intentionally omitted."""
        return {
            'name': self.name,
            'unique_id': self.unique_id,
            'output_paths': self.output_paths,
            'params': self.params,
            'processing_time': self.processing_time,
            'is_complete': self.is_complete,
            'task_log': self.task_log,
            'requires': self.requires,
        }
class RequiredTask(NamedTuple):
    """Lightweight reference to a dependency: its class name and unique id."""

    name: str
    unique_id: str
def _make_requires_info(requires):
    """Convert a ``requires()`` value into RequiredTask records, preserving dict/list shape.

    Raises TypeError for any other shape.
    """
    if isinstance(requires, TaskOnKart):
        return RequiredTask(name=requires.__class__.__name__, unique_id=requires.make_unique_id())
    elif isinstance(requires, dict):
        return {key: _make_requires_info(requires=item) for key, item in requires.items()}
    elif isinstance(requires, typing.Iterable):
        return [_make_requires_info(requires=item) for item in requires]
    # Fixed typo in the error message ('Iterarble' -> 'Iterable').
    raise TypeError(f'`requires` has unexpected type {type(requires)}. Must be `TaskOnKart`, `Iterable[TaskOnKart]`, or `Dict[str, TaskOnKart]`')
def make_task_info_tree(task: TaskOnKart[Any], ignore_task_names: list[str] | None = None, cache: dict[str, TaskInfo] | None = None) -> TaskInfo:
    """Recursively build the TaskInfo tree for ``task`` and its requirements.

    ``cache`` memoizes nodes by (name, unique id, completeness) so shared
    dependencies are computed once. Children whose class name is listed in
    ``ignore_task_names`` are pruned (together with their subtrees).
    """
    with warnings.catch_warnings():
        # luigi warns for output-less tasks when complete() is probed; not useful here.
        warnings.filterwarnings(action='ignore', message='Task .* without outputs has no custom complete() method')
        is_task_complete = task.complete()
    name = task.__class__.__name__
    unique_id = task.make_unique_id()
    output_paths: list[str] = [t.path() for t in flatten(task.output())]

    cache = {} if cache is None else cache
    # Completeness is part of the key: the same task may flip state mid-traversal.
    cache_id = f'{name}_{unique_id}_{is_task_complete}'
    if cache_id in cache:
        return cache[cache_id]

    params = task.get_info(only_significant=True)
    processing_time = task.get_processing_time()
    if isinstance(processing_time, float):
        processing_time = str(processing_time) + 's'
    is_complete = 'COMPLETE' if is_task_complete else 'PENDING'
    task_log = dict(task.get_task_log())
    requires = _make_requires_info(task.requires())

    children = flatten(task.requires())
    children_task_infos: list[TaskInfo] = []
    for child in children:
        if ignore_task_names is None or child.__class__.__name__ not in ignore_task_names:
            children_task_infos.append(make_task_info_tree(child, ignore_task_names=ignore_task_names, cache=cache))
    task_info = TaskInfo(
        name=name,
        unique_id=unique_id,
        output_paths=output_paths,
        params=params,
        processing_time=processing_time,
        is_complete=is_complete,
        task_log=task_log,
        requires=requires,
        children_task_infos=children_task_infos,
    )
    cache[cache_id] = task_info
    return task_info
def make_tree_info(task_info: TaskInfo, indent: str, last: bool, details: bool, abbr: bool, visited_tasks: set[str]) -> str:
    """Recursively render ``task_info`` and its children as an ASCII dependency tree.

    - indent: prefix accumulated from ancestor levels.
    - last: whether this node is its parent's final child (selects the branch glyph).
    - details: append parameters/outputs/timing/log after each title.
    - abbr: collapse already-rendered tasks into '...'; mutates visited_tasks.
    """
    result = '\n' + indent
    if last:
        result += '└─-'
        indent += '    '
    else:
        result += '|--'
        indent += '|   '
    result += task_info.get_task_title()
    if abbr:
        task_id = task_info.get_task_id()
        if task_id not in visited_tasks:
            visited_tasks.add(task_id)
        else:
            # This subtree was already printed elsewhere; abbreviate it.
            result += f'\n{indent}└─- ...'
            return result
    if details:
        result += task_info.get_task_detail()
    children = task_info.children_task_infos
    for index, child in enumerate(children):
        result += make_tree_info(child, indent, (index + 1) == len(children), details=details, abbr=abbr, visited_tasks=visited_tasks)
    return result
def make_tree_info_table_list(task_info: TaskInfo, visited_tasks: set[str]) -> list[dict[str, typing.Any]]:
    """Flatten the task-info tree into row dicts (pre-order), skipping repeated tasks.

    ``visited_tasks`` is mutated to track which task ids were already emitted.
    """
    task_id = task_info.get_task_id()
    if task_id in visited_tasks:
        return []
    visited_tasks.add(task_id)
    rows = [task_info.task_info_dict()]
    for child in task_info.children_task_infos:
        rows.extend(make_tree_info_table_list(task_info=child, visited_tasks=visited_tasks))
    return rows
================================================
FILE: gokart/utils.py
================================================
from __future__ import annotations
import os
from collections.abc import Callable, Iterable
from io import BytesIO
from typing import Any, Literal, Protocol, TypeAlias, TypeVar, get_args, get_origin
import dill
import luigi
import pandas as pd
class FileLike(Protocol):
    """Structural type for readable, seekable binary streams accepted by the loaders."""

    def read(self, n: int) -> bytes: ...

    def readline(self) -> bytes: ...

    def seek(self, offset: int) -> None: ...

    def seekable(self) -> bool: ...
def add_config(file_path: str) -> None:
    """Register ``file_path`` as a luigi configuration file.

    luigi selects its config parser from ``luigi.configuration.core.PARSER``,
    whose valid values are bare extensions such as ``'ini'`` or ``'toml'``.
    The original code assigned a dotted extension (e.g. ``'.ini'``) to a
    lowercase ``parser`` attribute, which luigi never reads; fixed to set the
    upper-case ``PARSER`` with the leading dot stripped.

    Raises:
        AssertionError: If luigi rejects the path (e.g. the file does not exist).
    """
    _, ext = os.path.splitext(file_path)
    # luigi keys its parser registry by bare extension, so drop the leading dot.
    luigi.configuration.core.PARSER = ext[1:]  # type: ignore
    assert luigi.configuration.add_config_path(file_path)
T = TypeVar('T')
FlattenableItems: TypeAlias = T | Iterable['FlattenableItems[T]'] | dict[str, 'FlattenableItems[T]']
def flatten(targets: FlattenableItems[T]) -> list[T]:
"""
Creates a flat list of all items in structured output (dicts, lists, items):
.. code-block:: python
>>> sorted(flatten({'a': 'foo', 'b': 'bar'}))
['bar', 'foo']
>>> sorted(flatten(['foo', ['bar', 'troll']]))
['bar', 'foo', 'troll']
>>> flatten('foo')
['foo']
>>> flatten(42)
[42]
This method is copied and modified from [luigi.task.flatten](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/luigi/task.py#L958) in accordance with [Apache License 2.0](https://github.com/spotify/luigi/blob/367edc2e3a099b8a0c2d15b1676269e33ad06117/LICENSE).
"""
if targets is None:
return []
flat = []
if isinstance(targets, dict):
for _, result in targets.items():
flat += flatten(result)
return flat
if isinstance(targets, str):
return [targets] # type: ignore
if not isinstance(targets, Iterable):
return [targets]
for result in targets:
flat += flatten(result)
return flat
K = TypeVar('K')


def map_flattenable_items(func: Callable[[T], K], items: FlattenableItems[T]) -> FlattenableItems[K]:
    """Apply ``func`` to every leaf of ``items`` while preserving the container shape.

    dicts map to dicts with the same keys, tuples to tuples, any other iterable
    to a list; a str is treated as a single leaf, not as a sequence of characters.
    """
    if isinstance(items, dict):
        return {key: map_flattenable_items(func, value) for key, value in items.items()}
    if isinstance(items, tuple):
        return tuple(map_flattenable_items(func, element) for element in items)
    if isinstance(items, str):
        return func(items)  # type: ignore
    if isinstance(items, Iterable):
        return [map_flattenable_items(func, element) for element in items]
    return func(items)
def load_dill_with_pandas_backward_compatibility(file: FileLike | BytesIO) -> Any:
    """Deserialize ``file``, trying dill first and pandas' pickle reader as a fallback.

    ``pd.read_pickle`` can load binaries dumped by older pandas versions as well
    as plain pickles; since it is unclear whether every dill payload loads
    through ``pd.read_pickle``, ``dill.load`` remains the primary path and the
    pandas reader is only used when it fails.
    """
    try:
        loaded = dill.load(file)
    except Exception:
        # The failed attempt consumed part of the stream; rewind before retrying.
        assert file.seekable(), f'{file} is not seekable.'
        file.seek(0)
        loaded = pd.read_pickle(file)
    return loaded
def get_dataframe_type_from_task(task: Any) -> Literal['pandas', 'polars', 'polars-lazy']:
    """
    Extract the DataFrame flavor from a ``TaskOnKart[T]`` type parameter.

    Walks the MRO of ``task`` (instance or class) looking for a generic base
    whose origin is named ``TaskOnKart`` and inspects its first type argument.

    Args:
        task: A TaskOnKart instance or class.

    Returns:
        'pandas', 'polars', or 'polars-lazy'; defaults to 'pandas' when the
        type parameter cannot be determined.
    """
    klass = task if isinstance(task, type) else type(task)
    # Walk the MRO so TaskOnKart[...] declared on a parent class is found too.
    candidates = klass.mro() if hasattr(klass, 'mro') else [klass]
    for candidate in candidates:
        for generic_base in getattr(candidate, '__orig_bases__', ()):
            base_origin = get_origin(generic_base)
            if not (base_origin and getattr(base_origin, '__name__', None) == 'TaskOnKart'):
                continue
            type_args = get_args(generic_base)
            if not type_args:
                continue
            frame_type = type_args[0]
            frame_module = getattr(frame_type, '__module__', '')
            # The defining module tells pandas and polars apart.
            if 'polars' in frame_module:
                return 'polars-lazy' if getattr(frame_type, '__name__', '') == 'LazyFrame' else 'polars'
            if 'pandas' in frame_module:
                return 'pandas'
    return 'pandas'  # Default to pandas for backward compatibility
================================================
FILE: gokart/worker.py
================================================
#
# Copyright 2012-2015 Spotify AB
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
"""
The worker communicates with the scheduler and does two things:
1. Sends all tasks that has to be run
2. Gets tasks from the scheduler that should be run
When running in local mode, the worker talks directly to a :py:class:`~luigi.scheduler.Scheduler` instance.
When you run a central server, the worker will talk to the scheduler using a :py:class:`~luigi.rpc.RemoteScheduler` instance.
Everything in this module is private to luigi and may change in incompatible
ways between versions. The exception is the exception types and the
:py:class:`worker` config class.
"""
from __future__ import annotations
import collections
import collections.abc
import contextlib
import datetime
import functools
import getpass
import importlib
import json
import logging
import multiprocessing
import os
import queue as Queue
import random
import signal
import socket
import subprocess
import sys
import threading
import time
import traceback
from collections.abc import Generator
from typing import Any, Literal, cast
import luigi
import luigi.scheduler
import luigi.worker
from luigi import notifications
from luigi.event import Event
from luigi.scheduler import DISABLED, DONE, FAILED, PENDING, UNKNOWN, WORKER_STATE_ACTIVE, WORKER_STATE_DISABLED, RetryPolicy, Scheduler
from luigi.target import Target
from luigi.task import DynamicRequirements, Task, flatten
from luigi.task_register import TaskClassException, load_task
from luigi.task_status import RUNNING
from gokart.parameter import ExplicitBoolParameter
# Module-level logger for worker lifecycle and task-status events.
logger = logging.getLogger(__name__)

# Use fork context instead of the default (spawn on macOS), which ensures compatibility with gokart's multiprocessing requirements.
_fork_context = multiprocessing.get_context('fork')
_ForkProcess = _fork_context.Process

# Prevent fork() from being called during a C-level getaddrinfo() which uses a process-global mutex,
# that may not be unlocked in child process, resulting in the process being locked indefinitely.
fork_lock = threading.Lock()

# Why we assert on _WAIT_INTERVAL_EPS:
# multiprocessing.Queue.get() is undefined for timeout=0 it seems:
# https://docs.python.org/3.4/library/multiprocessing.html#multiprocessing.Queue.get.
# I also tried with really low epsilon, but then ran into the same issue where
# the test case "test_external_dependency_worker_is_patient" got stuck. So I
# unscientifically just set the final value to a floating point number that
# "worked for me".
_WAIT_INTERVAL_EPS = 0.00001
def _is_external(task: Task) -> bool:
return task.run is None or task.run == NotImplemented
def _get_retry_policy_dict(task: Task) -> dict[str, Any]:
    # Pack the task's retry-related attributes into the dict form the scheduler API expects.
    return RetryPolicy(task.retry_count, task.disable_hard_timeout, task.disable_window)._asdict()  # type: ignore
# Lightweight record of one scheduler get_work() reply; used for logging and for
# deciding whether a keep-alive worker should keep polling.
GetWorkResponse = collections.namedtuple(
    'GetWorkResponse',
    (
        'task_id',
        'running_tasks',
        'n_pending_tasks',
        'n_unique_pending',
        'n_pending_last_scheduled',
        'worker_state',
    ),
)
class TaskProcess(_ForkProcess):  # type: ignore[valid-type, misc]
    """Wrap all task execution in this class.

    Mainly for convenience since this is run in a separate process."""

    # mapping of status_reporter attributes to task attributes that are added to tasks
    # before they actually run, and removed afterwards
    forward_reporter_attributes = {
        'update_tracking_url': 'set_tracking_url',
        'update_status_message': 'set_status_message',
        'update_progress_percentage': 'set_progress_percentage',
        'decrease_running_resources': 'decrease_running_resources',
        'scheduler_messages': 'scheduler_messages',
    }

    def __init__(
        self,
        task: luigi.Task,
        worker_id: str,
        result_queue: multiprocessing.Queue[Any],
        status_reporter: luigi.worker.TaskStatusReporter,
        use_multiprocessing: bool = False,
        worker_timeout: int = 0,
        check_unfulfilled_deps: bool = True,
        check_complete_on_run: bool = False,
        task_completion_cache: dict[str, Any] | None = None,
        task_completion_check_at_run: bool = True,
    ) -> None:
        super().__init__()
        self.task = task
        self.worker_id = worker_id
        self.result_queue = result_queue
        self.status_reporter = status_reporter
        # A task-level worker_timeout (if set) overrides the worker-wide default.
        self.worker_timeout = task.worker_timeout if task.worker_timeout is not None else worker_timeout
        self.timeout_time = time.time() + self.worker_timeout if self.worker_timeout else None
        # A timeout can only be enforced by killing a separate process, so force multiprocessing when one is set.
        self.use_multiprocessing = use_multiprocessing or self.timeout_time is not None
        self.check_unfulfilled_deps = check_unfulfilled_deps
        self.check_complete_on_run = check_complete_on_run
        self.task_completion_cache = task_completion_cache
        self.task_completion_check_at_run = task_completion_check_at_run
        # completeness check using the cache
        self.check_complete = functools.partial(luigi.worker.check_complete_cached, completion_cache=task_completion_cache)

    def _run_task(self) -> collections.abc.Generator[Any, Any, Any] | None:
        # Re-check completeness right before running: the task may have been finished
        # elsewhere (e.g. by another worker) between scheduling and execution.
        if self.task_completion_check_at_run and self.check_complete(self.task):
            logger.warning(f'{self.task} is skipped because the task is already completed.')
            return None
        return cast(collections.abc.Generator[Any, Any, Any] | None, self.task.run())

    def _run_get_new_deps(self) -> list[tuple[str, str, dict[str, str]]] | None:
        """Drive the task's run() generator; return unmet dynamic dependencies, or None when done."""
        task_gen = self._run_task()
        if not isinstance(task_gen, collections.abc.Generator):
            return None

        next_send = None
        while True:
            try:
                if next_send is None:
                    requires = next(task_gen)
                else:
                    requires = task_gen.send(next_send)
            except StopIteration:
                return None

            # if requires is not a DynamicRequirements, create one to use its default behavior
            if not isinstance(requires, DynamicRequirements):
                requires = DynamicRequirements(requires)

            if not requires.complete(self.check_complete):
                # not all requirements are complete, return them which adds them to the tree
                new_deps = [(t.task_module, t.task_family, t.to_str_params()) for t in requires.flat_requirements]
                return new_deps

            # get the next generator result
            next_send = requires.paths

    def run(self) -> None:
        logger.info('[pid %s] Worker %s running %s', os.getpid(), self.worker_id, self.task)

        if self.use_multiprocessing:
            # Need to have different random seeds if running in separate processes
            processID = os.getpid()
            currentTime = time.time()
            random.seed(processID * currentTime)

        status: str | None = FAILED
        expl = ''
        missing: list[str] = []
        new_deps: list[tuple[str, str, dict[str, str]]] | None = []
        try:
            # Verify that all the tasks are fulfilled! For external tasks we
            # don't care about unfulfilled dependencies, because we are just
            # checking completeness of self.task so outputs of dependencies are
            # irrelevant.
            if self.check_unfulfilled_deps and not _is_external(self.task):
                missing = []
                for dep in self.task.deps():
                    if not self.check_complete(dep):
                        nonexistent_outputs = [output for output in flatten(dep.output()) if not output.exists()]
                        if nonexistent_outputs:
                            missing.append(f'{dep.task_id} ({", ".join(map(str, nonexistent_outputs))})')
                        else:
                            missing.append(dep.task_id)
                if missing:
                    deps = 'dependency' if len(missing) == 1 else 'dependencies'
                    raise RuntimeError('Unfulfilled {} at run time: {}'.format(deps, ', '.join(missing)))
            self.task.trigger_event(Event.START, self.task)
            t0 = time.time()
            status = None

            if _is_external(self.task):
                # External task
                if self.check_complete(self.task):
                    status = DONE
                else:
                    status = FAILED
                    expl = 'Task is an external data dependency and data does not exist (yet?).'
            else:
                with self._forward_attributes():
                    new_deps = self._run_get_new_deps()
                if not new_deps:
                    if not self.check_complete_on_run:
                        # update the cache
                        if self.task_completion_cache is not None:
                            self.task_completion_cache[self.task.task_id] = True
                        status = DONE
                    elif self.check_complete(self.task):
                        status = DONE
                    else:
                        raise luigi.worker.TaskException('Task finished running, but complete() is still returning false.')
                else:
                    # Dynamic dependencies surfaced: report PENDING so they get scheduled first.
                    status = PENDING

            if new_deps:
                logger.info('[pid %s] Worker %s new requirements %s', os.getpid(), self.worker_id, self.task)
            elif status == DONE:
                self.task.trigger_event(Event.PROCESSING_TIME, self.task, time.time() - t0)
                expl = self.task.on_success()
                logger.info('[pid %s] Worker %s done %s', os.getpid(), self.worker_id, self.task)
                self.task.trigger_event(Event.SUCCESS, self.task)
        except KeyboardInterrupt:
            raise
        except BaseException as ex:
            status = FAILED
            expl = self._handle_run_exception(ex)
        finally:
            # Always report the outcome back to the parent worker process.
            self.result_queue.put((self.task.task_id, status, expl, missing, new_deps))

    def _handle_run_exception(self, ex: BaseException) -> str:
        """Log the failure, fire the FAILURE event, and return the task's own explanation string."""
        logger.exception('[pid %s] Worker %s failed %s', os.getpid(), self.worker_id, self.task)
        self.task.trigger_event(Event.FAILURE, self.task, ex)
        return cast(str, self.task.on_failure(ex))

    def _recursive_terminate(self) -> None:
        # psutil is optional; terminate() falls back to plain terminate when it is missing.
        import psutil

        try:
            parent = psutil.Process(self.pid)
            children = parent.children(recursive=True)

            # terminate parent. Give it a chance to clean up
            super().terminate()
            parent.wait()

            # terminate children
            for child in children:
                try:
                    child.terminate()
                except psutil.NoSuchProcess:
                    continue
        except psutil.NoSuchProcess:
            return

    def terminate(self) -> None:
        """Terminate this process and its subprocesses."""
        # default terminate() doesn't cleanup child processes, it orphans them.
        try:
            return self._recursive_terminate()
        except ImportError:
            super().terminate()

    @contextlib.contextmanager
    def _forward_attributes(self):
        # forward configured attributes to the task
        for reporter_attr, task_attr in self.forward_reporter_attributes.items():
            setattr(self.task, task_attr, getattr(self.status_reporter, reporter_attr))
        try:
            yield self
        finally:
            # reset attributes again
            for _, task_attr in self.forward_reporter_attributes.items():
                setattr(self.task, task_attr, None)
# This code and the task_process_context config key currently feels a bit ad-hoc.
# Discussion on generalizing it into a plugin system: https://github.com/spotify/luigi/issues/1897
class ContextManagedTaskProcess(TaskProcess):
    """A TaskProcess whose run() is wrapped in a user-supplied context manager.

    ``context`` is a dotted ``module.Class`` path; the class is instantiated
    with this TaskProcess and used as a context manager around the task run,
    enabling custom monitoring/logging of each individual run.
    """

    def __init__(self, context: Any, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self.context = context

    def run(self) -> None:
        if self.context:
            logger.debug('Importing module and instantiating ' + self.context)
            module_path, class_name = self.context.rsplit('.', 1)
            module = importlib.import_module(module_path)
            cls = getattr(module, class_name)
            # The context class receives this TaskProcess so it can inspect the task.
            with cls(self):
                super().run()
        else:
            super().run()
class gokart_worker(luigi.Config):
    """Configuration for the gokart worker.

    You can set these options of section [gokart_worker] in your luigi.cfg file.
    NOTE: use snake_case for this class to match the luigi.Config convention.
    """

    id: luigi.StrParameter = luigi.StrParameter(default='', description='Override the auto-generated worker_id')
    ping_interval: luigi.FloatParameter = luigi.FloatParameter(
        default=1.0,
        config_path=dict(section='core', name='worker-ping-interval'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
    )
    keep_alive: luigi.BoolParameter = luigi.BoolParameter(
        default=False,
        config_path=dict(section='core', name='worker-keep-alive'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
    )
    count_uniques: luigi.BoolParameter = luigi.BoolParameter(
        default=False,
        config_path=dict(section='core', name='worker-count-uniques'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
        description='worker-count-uniques means that we will keep a worker alive only if it has a unique pending task, as well as having keep-alive true',
    )
    count_last_scheduled: luigi.BoolParameter = luigi.BoolParameter(
        default=False, description='Keep a worker alive only if there are pending tasks which it was the last to schedule.'
    )
    wait_interval: luigi.FloatParameter = luigi.FloatParameter(
        default=1.0,
        config_path=dict(section='core', name='worker-wait-interval'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
    )
    wait_jitter: luigi.FloatParameter = luigi.FloatParameter(default=5.0)
    max_keep_alive_idle_duration: luigi.TimeDeltaParameter = luigi.TimeDeltaParameter(default=datetime.timedelta(0))
    max_reschedules: luigi.IntParameter = luigi.IntParameter(
        default=1,
        config_path=dict(section='core', name='worker-max-reschedules'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
    )
    timeout: luigi.IntParameter = luigi.IntParameter(
        default=0,
        config_path=dict(section='core', name='worker-timeout'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
    )
    task_limit: luigi.OptionalIntParameter = luigi.OptionalIntParameter(
        default=None,  # type: ignore[arg-type] # OptionalIntParameter.__init__ inherits IntParameter's signature
        config_path=dict(section='core', name='worker-task-limit'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
    )
    retry_external_tasks: luigi.BoolParameter = luigi.BoolParameter(
        default=False,
        config_path=dict(section='core', name='retry-external-tasks'),  # type: ignore # fix https://github.com/spotify/luigi/pull/3403
        description='If true, incomplete external tasks will be retested for completion while Luigi is running.',
    )
    # Description strings fixed: the originals inherited luigi's adjacent-string-literal
    # typos ('workeron failure', 'willNOT be install').
    send_failure_email: luigi.BoolParameter = luigi.BoolParameter(default=True, description='If true, send e-mails directly from the worker on failure')
    no_install_shutdown_handler: luigi.BoolParameter = luigi.BoolParameter(
        default=False, description='If true, the SIGUSR1 shutdown handler will NOT be installed on the worker'
    )
    check_unfulfilled_deps: luigi.BoolParameter = luigi.BoolParameter(
        default=True, description='If true, check for completeness of dependencies before running a task'
    )
    check_complete_on_run: luigi.BoolParameter = luigi.BoolParameter(
        default=False,
        description='If true, only mark tasks as done after running if they are complete. '
        'Regardless of this setting, the worker will always check if external '
        'tasks are complete before marking them as done.',
    )
    force_multiprocessing: luigi.BoolParameter = luigi.BoolParameter(default=False, description='If true, use multiprocessing also when running with 1 worker')
    task_process_context: luigi.OptionalStrParameter = luigi.OptionalStrParameter(
        default=None,
        description='If set to a fully qualified class name, the class will '
        'be instantiated with a TaskProcess as its constructor parameter and '
        'applied as a context manager around its run() call, so this can be '
        'used for obtaining high level customizable monitoring or logging of '
        'each individual Task run.',
    )
    cache_task_completion: luigi.BoolParameter = luigi.BoolParameter(
        default=False,
        description='If true, cache the response of successful completion checks '
        'of tasks assigned to a worker. This can especially speed up tasks with '
        'dynamic dependencies but assumes that the completion status does not change '
        'after it was true the first time.',
    )
    task_completion_check_at_run: luigi.BoolParameter = ExplicitBoolParameter(
        default=True, description='If true, tasks completeness will be re-checked just before the run, in case they are finished elsewhere.'
    )
class Worker:
"""
Worker object communicates with a scheduler.
Simple class that talks to a scheduler and:
* tells the scheduler what it has to do + its dependencies
* asks for stuff to do (pulls it in a loop and runs it)
"""
    def __init__(
        self,
        scheduler: Scheduler | None = None,
        worker_id: str | None = None,
        worker_processes: int = 1,
        assistant: bool = False,
        config: gokart_worker | None = None,
    ) -> None:
        """Create a worker bound to ``scheduler``.

        Args:
            scheduler: Scheduler to talk to; a local ``Scheduler()`` is created when omitted.
            worker_id: Explicit worker id; falls back to config ``id``, then an auto-generated one.
            worker_processes: Number of processes this worker may use.
            assistant: If True, run as an assistant worker.
            config: Worker configuration; defaults to a fresh ``gokart_worker()``.
        """
        if scheduler is None:
            scheduler = Scheduler()
        self.worker_processes = int(worker_processes)
        self._worker_info = self._generate_worker_info()
        if config is None:
            self._config = gokart_worker()
        else:
            self._config = config
        worker_id = worker_id or self._config.id or self._generate_worker_id(self._worker_info)

        assert self._config.wait_interval >= _WAIT_INTERVAL_EPS, '[worker] wait_interval must be positive'
        assert self._config.wait_jitter >= 0.0, '[worker] wait_jitter must be equal or greater than zero'

        self._id = worker_id
        self._scheduler = scheduler
        self._assistant = assistant
        self._stop_requesting_work = False

        self.host = socket.gethostname()
        # Book-keeping of every task this worker has announced to the scheduler.
        self._scheduled_tasks: dict[str, Task] = {}
        self._suspended_tasks: dict[str, Task] = {}
        self._batch_running_tasks: dict[str, Any] = {}
        self._batch_families_sent: set[str] = set()

        self._first_task = None

        self.add_succeeded = True
        self.run_succeeded = True

        self.unfulfilled_counts: dict[str, int] = collections.defaultdict(int)

        # note that ``signal.signal(signal.SIGUSR1, fn)`` only works inside the main execution thread, which is why we
        # provide the ability to conditionally install the hook.
        if not self._config.no_install_shutdown_handler:
            try:
                signal.signal(signal.SIGUSR1, self.handle_interrupt)
                signal.siginterrupt(signal.SIGUSR1, False)
            except AttributeError:
                # SIGUSR1 / siginterrupt are unavailable on some platforms (e.g. Windows).
                pass

        # Keep info about what tasks are running (could be in other processes)
        self._task_result_queue: multiprocessing.Queue[Any] = _fork_context.Queue()
        self._running_tasks: dict[str, TaskProcess] = {}
        self._idle_since: datetime.datetime | None = None

        # mp-safe dictionary for caching completion checks across task processes
        self._task_completion_cache = None
        if self._config.cache_task_completion:
            self._task_completion_cache = _fork_context.Manager().dict()

        # Stuff for execution_summary
        self._add_task_history: list[Any] = []
        self._get_work_response_history: list[Any] = []
    def _add_task(self, *args, **kwargs):
        """
        Call ``self._scheduler.add_task``, but store the values too so we can
        implement :py:func:`luigi.execution_summary.summary`.
        """
        task_id = kwargs['task_id']
        status = kwargs['status']
        runnable = kwargs['runnable']
        task = self._scheduled_tasks.get(task_id)
        if task:
            # Record the transition for the execution summary and expose owners to the scheduler.
            self._add_task_history.append((task, status, runnable))
            kwargs['owners'] = task._owner_list()

        if task_id in self._batch_running_tasks:
            # This task represented a batch: propagate the status to every batched task.
            for batch_task in self._batch_running_tasks.pop(task_id):
                self._add_task_history.append((batch_task, status, True))

        if task and kwargs.get('params'):
            kwargs['param_visibilities'] = task._get_param_visibilities()

        self._scheduler.add_task(*args, **kwargs)

        logger.info('Informed scheduler that task %s has status %s', task_id, status)
    def __enter__(self) -> Worker:
        """
        Start the KeepAliveThread.
        """
        # The keep-alive thread pings the scheduler periodically and dispatches RPC messages.
        self._keep_alive_thread = luigi.worker.KeepAliveThread(self._scheduler, self._id, self._config.ping_interval, self._handle_rpc_message)
        self._keep_alive_thread.daemon = True
        self._keep_alive_thread.start()
        return self
    def __exit__(self, type: Any, value: Any, traceback: Any) -> Literal[False]:
        """
        Stop the KeepAliveThread and kill still running tasks.
        """
        self._keep_alive_thread.stop()
        self._keep_alive_thread.join()
        for task in self._running_tasks.values():
            if task.is_alive():
                task.terminate()
        self._task_result_queue.close()
        return False  # Don't suppress exception
def _generate_worker_info(self) -> list[tuple[str, Any]]:
# Generate as much info as possible about the worker
# Some of these calls might not be available on all OS's
args = [('salt', f'{random.randrange(0, 10_000_000_000):09d}'), ('workers', self.worker_processes)]
try:
args += [('host', socket.gethostname())]
except BaseException:
pass
try:
args += [('username', getpass.getuser())]
except BaseException:
pass
try:
args += [('pid', os.getpid())]
except BaseException:
pass
try:
sudo_user = os.getenv('SUDO_USER')
if sudo_user:
args.append(('sudo_user', sudo_user))
except BaseException:
pass
return args
def _generate_worker_id(self, worker_info: list[Any]) -> str:
worker_info_str = ', '.join([f'{k}={v}' for k, v in worker_info])
return f'Worker({worker_info_str})'
    def _validate_task(self, task: Task) -> None:
        # Reject anything that is not a properly initialized luigi Task before scheduling it.
        if not isinstance(task, Task):
            raise luigi.worker.TaskException(f'Can not schedule non-task {task}')

        if not task.initialized():
            # we can't get the repr of it since it's not initialized...
            raise luigi.worker.TaskException(
                f'Task of class {task.__class__.__name__} not initialized. Did you override __init__ and forget to call super(...).__init__?'
            )
def _log_complete_error(self, task: Task, tb: str) -> None:
log_msg = f'Will not run {task} or any dependencies due to error in complete() method:\n{tb}'
logger.warning(log_msg)
def _log_dependency_error(self, task: Task, tb: str) -> None:
log_msg = f'Will not run {task} or any dependencies due to error in deps() method:\n{tb}'
logger.warning(log_msg)
    def _log_unexpected_error(self, task: Task) -> None:
        # Framework-level (non-task) failure; logs the currently-handled exception's traceback.
        logger.exception('Luigi unexpected framework error while scheduling %s', task)  # needs to be called from within except clause
    def _announce_scheduling_failure(self, task: Task, expl: Any) -> None:
        """Tell the scheduler that scheduling ``task`` failed; if the announcement itself fails, e-mail and re-raise."""
        try:
            self._scheduler.announce_scheduling_failure(
                worker=self._id,
                task_name=str(task),
                family=task.task_family,
                params=task.to_str_params(only_significant=True),
                expl=expl,
                owners=task._owner_list(),
            )
        except Exception:
            formatted_traceback = traceback.format_exc()
            self._email_unexpected_error(task, formatted_traceback)
            raise
    def _email_complete_error(self, task: Task, formatted_traceback: str) -> None:
        """Announce and (if failure e-mails are enabled) e-mail that complete() raised for ``task``."""
        self._announce_scheduling_failure(task, formatted_traceback)
        if self._config.send_failure_email:
            self._email_error(
                task,
                formatted_traceback,
                subject='Luigi: {task} failed scheduling. Host: {host}',
                headline='Will not run {task} or any dependencies due to error in complete() method',
            )
    def _email_dependency_error(self, task: Task, formatted_traceback: str) -> None:
        """Announce and (if failure e-mails are enabled) e-mail that deps() raised for ``task``."""
        self._announce_scheduling_failure(task, formatted_traceback)
        if self._config.send_failure_email:
            self._email_error(
                task,
                formatted_traceback,
                subject='Luigi: {task} failed scheduling. Host: {host}',
                headline='Will not run {task} or any dependencies due to error in deps() method',
            )
    def _email_unexpected_error(self, task: Task, formatted_traceback: str) -> None:
        # this sends even if failure e-mails are disabled, as they may indicate
        # a more severe failure that may not reach other alerting methods such
        # as scheduler batch notification
        self._email_error(
            task,
            formatted_traceback,
            subject='Luigi: Framework error while scheduling {task}. Host: {host}',
            headline='Luigi framework error',
        )
    def _email_task_failure(self, task: Task, formatted_traceback: str) -> None:
        """E-mail a run-time task failure, unless failure e-mails are disabled in config."""
        if self._config.send_failure_email:
            self._email_error(
                task,
                formatted_traceback,
                subject='Luigi: {task} FAILED. Host: {host}',
                headline='A task failed when running. Most likely run() raised an exception.',
            )
    def _email_error(self, task: Task, formatted_traceback: str, subject: str, headline: str) -> None:
        """Format and send an error e-mail; ``subject``/``headline`` may reference ``{task}`` and ``{host}``."""
        formatted_subject = subject.format(task=task, host=self.host)
        formatted_headline = headline.format(task=task, host=self.host)
        # Include the exact command line that launched this worker to ease reproduction.
        command = subprocess.list2cmdline(sys.argv)
        message = notifications.format_task_error(formatted_headline, task, command, formatted_traceback)
        notifications.send_error_email(formatted_subject, message, task.owner_email)
    def _handle_task_load_error(self, exception: Exception, task_ids: list[str]) -> None:
        """Mark every scheduler-sent task id as FAILED when its class cannot be loaded, then e-mail the error."""
        msg = 'Cannot find task(s) sent by scheduler: {}'.format(','.join(task_ids))
        logger.exception(msg)
        subject = f'Luigi: {msg}'
        error_message = notifications.wrap_traceback(exception)
        for task_id in task_ids:
            self._add_task(
                worker=self._id,
                task_id=task_id,
                status=FAILED,
                runnable=False,
                expl=error_message,
            )
        notifications.send_error_email(subject, error_message)
def add(self, task: Task, multiprocess: bool = False, processes: int = 0) -> bool:
"""
Add a Task for the worker to check and possibly schedule and run.
Returns True if task and its dependencies were successfully scheduled or completed before.
"""
if self._first_task is None and hasattr(task, 'task_id'):
self._first_task = task.task_id
self.add_succeeded = True
if multiprocess:
queue: Any = _fork_context.Manager().Queue()
pool: Any = _fork_context.Pool(processes=processes if processes > 0 else None)
else:
queue = luigi.worker.DequeQueue()
pool = luigi.worker.SingleProcessPool()
self._validate_task(task)
pool.apply_async(luigi.worker.check_complete, [task, queue, self._task_completion_cache])
# we track queue size ourselves because len(queue) won't work for multiprocessing
queue_size = 1
try:
seen = {task.task_id}
while queue_size:
current = queue.get()
queue_size -= 1
item, is_complete = current
for next in self._add(item, is_complete):
if next.task_id not in seen:
self._validate_task(next)
seen.add(next.task_id)
pool.apply_async(luigi.worker.check_complete, [next, queue, self._task_completion_cache])
queue_size += 1
except (KeyboardInterrupt, luigi.worker.TaskException):
raise
except Exception as ex:
self.add_succeeded = False
formatted_traceback = traceback.format_exc()
self._log_unexpected_error(task)
task.trigger_event(Event.BROKEN_TASK, task, ex)
self._email_unexpected_error(task, formatted_traceback)
raise
finally:
pool.close()
pool.join()
return self.add_succeeded
    def _add_task_batcher(self, task: Task) -> None:
        # Tell the scheduler (once per task family) which parameters of this family may be batched together.
        family = task.task_family
        if family not in self._batch_families_sent:
            task_class = type(task)
            batch_param_names = task_class.batch_param_names()
            if batch_param_names:
                self._scheduler.add_task_batcher(
                    worker=self._id,
                    task_family=family,
                    batched_args=batch_param_names,
                    max_batch_size=task.max_batch_size,
                )
            self._batch_families_sent.add(family)
    def _add(self, task: Task, is_complete: bool) -> Generator[Task, None, None]:
        """Determine the scheduling status of ``task`` and register it with the scheduler.

        Yields the task's direct dependencies so the caller (``add``) can queue
        completeness checks for them in turn.
        """
        if self._config.task_limit is not None and len(self._scheduled_tasks) >= self._config.task_limit:
            logger.warning('Will not run %s or any dependencies due to exceeded task-limit of %d', task, self._config.task_limit)
            deps = None
            status = UNKNOWN
            runnable = False

        else:
            formatted_traceback = None
            try:
                self._check_complete_value(is_complete)
            except KeyboardInterrupt:
                raise
            except luigi.worker.AsyncCompletionException as ex:
                formatted_traceback = ex.trace
            except BaseException:
                formatted_traceback = traceback.format_exc()

            if formatted_traceback is not None:
                # complete() itself raised: mark UNKNOWN and notify; do not schedule.
                self.add_succeeded = False
                self._log_complete_error(task, formatted_traceback)
                task.trigger_event(Event.DEPENDENCY_MISSING, task)
                self._email_complete_error(task, formatted_traceback)
                deps = None
                status = UNKNOWN
                runnable = False

            elif is_complete:
                deps = None
                status = DONE
                runnable = False
                task.trigger_event(Event.DEPENDENCY_PRESENT, task)

            elif _is_external(task):
                # External tasks cannot be run here; they are only retried if configured.
                deps = None
                status = PENDING
                runnable = self._config.retry_external_tasks
                task.trigger_event(Event.DEPENDENCY_MISSING, task)
                logger.warning('Data for %s does not exist (yet?). The task is an external data dependency, so it cannot be run from this luigi process.', task)

            else:
                try:
                    deps = task.deps()
                    self._add_task_batcher(task)
                except Exception as ex:
                    # deps() raised: treat like a broken task and mark UNKNOWN.
                    formatted_traceback = traceback.format_exc()
                    self.add_succeeded = False
                    self._log_dependency_error(task, formatted_traceback)
                    task.trigger_event(Event.BROKEN_TASK, task, ex)
                    self._email_dependency_error(task, formatted_traceback)
                    deps = None
                    status = UNKNOWN
                    runnable = False
                else:
                    status = PENDING
                    runnable = True

        if task.disabled:
            status = DISABLED

        if deps:
            for d in deps:
                self._validate_dependency(d)
                task.trigger_event(Event.DEPENDENCY_DISCOVERED, task, d)
                yield d  # return additional tasks to add

            deps = [d.task_id for d in deps]

        self._scheduled_tasks[task.task_id] = task
        self._add_task(
            worker=self._id,
            task_id=task.task_id,
            status=status,
            deps=deps,
            runnable=runnable,
            priority=task.priority,
            resources=task.process_resources(),
            params=task.to_str_params(),
            family=task.task_family,
            module=task.task_module,
            batchable=task.batchable,
            retry_policy_dict=_get_retry_policy_dict(task),
            accepts_messages=task.accepts_messages,
        )
def _validate_dependency(self, dependency: Task) -> None:
if isinstance(dependency, Target):
raise Exception('requires() can not return Target objects. Wrap it in an ExternalTask class')
elif not isinstance(dependency, Task):
raise Exception(f'requires() must return Task objects but {dependency} is a {type(dependency)}')
    def _check_complete_value(self, is_complete: bool | luigi.worker.TracebackWrapper) -> None:
        # A TracebackWrapper means complete() raised inside the async check pool; re-raise it here.
        if isinstance(is_complete, luigi.worker.TracebackWrapper):
            raise luigi.worker.AsyncCompletionException(is_complete.trace)
        if not isinstance(is_complete, bool):
            raise Exception(f'Return value of Task.complete() must be boolean (was {is_complete!r})')
def _add_worker(self) -> None:
    """Register this worker (including its first task) with the scheduler."""
    self._worker_info.append(('first_task', self._first_task))
    self._scheduler.add_worker(self._id, self._worker_info)
def _log_remote_tasks(self, get_work_response: GetWorkResponse) -> None:
    """Emit debug logging explaining why no task was handed to this worker."""
    response = get_work_response
    logger.debug('Done')
    logger.debug('There are no more tasks to run at this time')
    if response.running_tasks:
        for remote in response.running_tasks:
            logger.debug('%s is currently run by worker %s', remote['task_id'], remote['worker'])
    elif response.n_pending_tasks:
        logger.debug('There are %s pending tasks possibly being run by other workers', response.n_pending_tasks)
    if response.n_unique_pending:
        logger.debug('There are %i pending tasks unique to this worker', response.n_unique_pending)
    if response.n_pending_last_scheduled:
        logger.debug('There are %i pending tasks last scheduled by this worker', response.n_pending_last_scheduled)
def _get_work_task_id(self, get_work_response: dict[str, Any]) -> str | None:
    """Extract the task id to run from a raw ``get_work`` response.

    Returns the plain ``task_id`` when present. For a batched response
    (``batch_id`` key) the combined batch task is instantiated via
    ``load_task`` and registered with the scheduler as RUNNING. Returns
    ``None`` when the response carries no work, or when loading the batch
    task fails (the failure is reported for every id in the batch).
    """
    if get_work_response.get('task_id') is not None:
        return cast(str, get_work_response['task_id'])
    elif 'batch_id' in get_work_response:
        try:
            task = load_task(
                module=get_work_response.get('task_module'),
                task_name=get_work_response['task_family'],
                params_str=get_work_response['task_params'],
            )
        except Exception as ex:
            # Report the load error against every task id in the batch,
            # then mark the run as failed instead of propagating.
            self._handle_task_load_error(ex, get_work_response['batch_task_ids'])
            self.run_succeeded = False
            return None
        # Register the synthesized batch task so the scheduler tracks it as RUNNING.
        self._scheduler.add_task(
            worker=self._id,
            task_id=task.task_id,
            module=get_work_response.get('task_module'),
            family=get_work_response['task_family'],
            params=task.to_str_params(),
            status=RUNNING,
            batch_id=get_work_response['batch_id'],
        )
        return cast(str, task.task_id)
    else:
        return None
def _get_work(self) -> GetWorkResponse:
    """Ask the scheduler for the next piece of work.

    Returns a ``GetWorkResponse``; its ``task_id`` is ``None`` when there is
    nothing to run (or this worker has been told to stop requesting work).
    Tasks that were scheduled by another worker are loaded dynamically, and
    batch members are remembered in ``self._batch_running_tasks``.
    """
    if self._stop_requesting_work:
        return GetWorkResponse(None, 0, 0, 0, 0, WORKER_STATE_DISABLED)
    if self.worker_processes > 0:
        logger.debug('Asking scheduler for work...')
        r = self._scheduler.get_work(
            worker=self._id,
            host=self.host,
            assistant=self._assistant,
            current_tasks=list(self._running_tasks.keys()),
        )
    else:
        # With zero worker processes we only poll pending counts so the
        # keep-alive logic can decide whether to shut down.
        logger.debug('Checking if tasks are still pending')
        r = self._scheduler.count_pending(worker=self._id)
    running_tasks = r['running_tasks']
    task_id = self._get_work_task_id(r)
    self._get_work_response_history.append(
        {
            'task_id': task_id,
            'running_tasks': running_tasks,
        }
    )
    if task_id is not None and task_id not in self._scheduled_tasks:
        logger.info('Did not schedule %s, will load it dynamically', task_id)
        try:
            # TODO: we should obtain the module name from the server!
            self._scheduled_tasks[task_id] = load_task(module=r.get('task_module'), task_name=r['task_family'], params_str=r['task_params'])
        except TaskClassException as ex:
            self._handle_task_load_error(ex, [task_id])
            task_id = None
            self.run_succeeded = False
    if task_id is not None and 'batch_task_ids' in r:
        # Materialize as a list rather than a lazy ``filter`` object: a
        # filter iterator can only be consumed once, so later re-iteration
        # of self._batch_running_tasks[task_id] would silently see nothing.
        batch_tasks = [task for task in (self._scheduled_tasks.get(batch_id) for batch_id in r['batch_task_ids']) if task]
        self._batch_running_tasks[task_id] = batch_tasks
    return GetWorkResponse(
        task_id=task_id,
        running_tasks=running_tasks,
        n_pending_tasks=r['n_pending_tasks'],
        n_unique_pending=r['n_unique_pending'],
        # TODO: For a tiny amount of time (a month?) we'll keep forwards compatibility
        # That is you can use a newer client than server (Sep 2016)
        n_pending_last_scheduled=r.get('n_pending_last_scheduled', 0),
        worker_state=r.get('worker_state', WORKER_STATE_ACTIVE),
    )
def _run_task(self, task_id: str) -> None:
    """Start executing the scheduled task *task_id*.

    Runs it either in a forked child process (when multiprocessing is in
    use) or inline in this process. If the scheduler hands back a task
    that is already running here, sleep one interval instead of starting
    it a second time.
    """
    if task_id in self._running_tasks:
        logger.debug(f'Got already running task id {task_id} from scheduler, taking a break')
        next(self._sleeper())
        return
    task = self._scheduled_tasks[task_id]
    task_process = self._create_task_process(task)
    self._running_tasks[task_id] = task_process
    if task_process.use_multiprocessing:
        # fork_lock serializes forking with other thread activity
        # (luigi.worker convention) so the child gets a consistent state.
        with fork_lock:
            task_process.start()
    else:
        # Run in the same process
        task_process.run()
def _create_task_process(self, task) -> ContextManagedTaskProcess:
    """Build the process wrapper that will execute *task*.

    A message queue is attached only when the task accepts scheduler
    messages. Multiprocessing is used when forced via config or when more
    than one worker process is configured.
    """
    message_queue: Any = _fork_context.Queue() if task.accepts_messages else None
    reporter = luigi.worker.TaskStatusReporter(self._scheduler, task.task_id, self._id, message_queue)
    use_multiprocessing = self._config.force_multiprocessing or bool(self.worker_processes > 1)
    return ContextManagedTaskProcess(
        self._config.task_process_context,
        task,
        self._id,
        self._task_result_queue,
        reporter,
        use_multiprocessing=use_multiprocessing,
        worker_timeout=self._config.timeout,
        check_unfulfilled_deps=self._config.check_unfulfilled_deps,
        check_complete_on_run=self._config.check_complete_on_run,
        task_completion_cache=self._task_completion_cache,
        task_completion_check_at_run=self._config.task_completion_check_at_run,
    )
def _purge_children(self) -> None:
    """
    Find dead children and put a response on the result queue.

    Handles two failure modes: a child that exited with a non-zero exit
    code, and a child that exceeded its timeout (which is terminated
    here). In both cases a FAILED result is queued on the task's behalf,
    since the child can no longer report for itself.
    :return:
    """
    for task_id, p in self._running_tasks.items():
        if not p.is_alive() and p.exitcode:
            error_msg = f'Task {task_id} died unexpectedly with exit code {p.exitcode}'
            p.task.trigger_event(Event.PROCESS_FAILURE, p.task, error_msg)
        elif p.timeout_time is not None and time.time() > float(p.timeout_time) and p.is_alive():
            p.terminate()
            error_msg = f'Task {task_id} timed out after {p.worker_timeout} seconds and was terminated.'
            p.task.trigger_event(Event.TIMEOUT, p.task, error_msg)
        else:
            # Child is healthy (or already handled): nothing to report.
            continue
        logger.info(error_msg)
        self._task_result_queue.put((task_id, FAILED, error_msg, [], []))
def _handle_next_task(self) -> None:
    """
    We have to catch three ways a task can be "done":
    1. normal execution: the task runs/fails and puts a result back on the queue,
    2. new dependencies: the task yielded new deps that were not complete and
       will be rescheduled and dependencies added,
    3. child process dies: we need to catch this separately.
    """
    self._idle_since = None
    while True:
        self._purge_children()  # Deal with subprocess failures
        try:
            task_id, status, expl, missing, new_requirements = self._task_result_queue.get(timeout=self._config.wait_interval)
        except Queue.Empty:
            # No result arrived within wait_interval; hand control back to run().
            return
        task = self._scheduled_tasks[task_id]
        if not task or task_id not in self._running_tasks:
            # Not a running task. Probably already removed.
            # Maybe it yielded something?
            continue
        # external task if run not implemented, retry-able if config option is enabled.
        external_task_retryable = _is_external(task) and self._config.retry_external_tasks
        if status == FAILED and not external_task_retryable:
            self._email_task_failure(task, expl)
        new_deps = []
        if new_requirements:
            # The task dynamically yielded new requirements; load and schedule each.
            new_req = [load_task(module, name, params) for module, name, params in new_requirements]
            for t in new_req:
                self.add(t)
            new_deps = [t.task_id for t in new_req]
        # Report the final status (with explanation and any new deps) to the scheduler.
        self._add_task(
            worker=self._id,
            task_id=task_id,
            status=status,
            expl=json.dumps(expl),
            resources=task.process_resources(),
            runnable=None,
            params=task.to_str_params(),
            family=task.task_family,
            module=task.task_module,
            new_deps=new_deps,
            assistant=self._assistant,
            retry_policy_dict=_get_retry_policy_dict(task),
        )
        self._running_tasks.pop(task_id)
        # re-add task to reschedule missing dependencies
        if missing:
            reschedule = True
            # keep out of infinite loops by not rescheduling too many times
            # NOTE(review): this loop rebinds ``task_id`` (shadowing the finished
            # task's id); harmless today because task_id is not read afterwards.
            for task_id in missing:
                self.unfulfilled_counts[task_id] += 1
                if self.unfulfilled_counts[task_id] > self._config.max_reschedules:
                    reschedule = False
            if reschedule:
                self.add(task)
        self.run_succeeded &= (status == DONE) or (len(new_deps) > 0)
        return
def _sleeper(self) -> Generator[None, None, None]:
    """Infinite generator: each ``next()`` sleeps for wait_interval plus jitter."""
    # TODO is exponential backoff necessary?
    while True:
        pause = self._config.wait_interval + random.uniform(0, self._config.wait_jitter)
        logger.debug('Sleeping for %f seconds', pause)
        time.sleep(pause)
        yield
def _keep_alive(self, get_work_response: Any) -> bool:
    """
    Returns true if a worker should stay alive given.

    If worker-keep-alive is not set, this will always return false.
    For an assistant, it will always return the value of worker-keep-alive.
    Otherwise, it will return true for nonzero n_pending_tasks.

    If worker-count-uniques is true, it will also
    require that one of the tasks is unique to this worker.

    When max_keep_alive_idle_duration is configured, a fully idle worker
    is allowed to shut down once it has been idle for longer than that.
    Note: the branch order below is significant — earlier config options
    take precedence over the idle-duration check.
    """
    if not self._config.keep_alive:
        return False
    elif self._assistant:
        return True
    elif self._config.count_last_scheduled:
        return cast(bool, get_work_response.n_pending_last_scheduled > 0)
    elif self._config.count_uniques:
        return cast(bool, get_work_response.n_unique_pending > 0)
    elif get_work_response.n_pending_tasks == 0:
        return False
    elif not self._config.max_keep_alive_idle_duration:
        return True
    elif not self._idle_since:
        return True
    else:
        # Positive time_to_shutdown means we may keep idling a while longer.
        time_to_shutdown = self._idle_since + self._config.max_keep_alive_idle_duration - datetime.datetime.now()
        logger.debug('[%s] %s until shutdown', self._id, time_to_shutdown)
        return time_to_shutdown > datetime.timedelta(0)
def handle_interrupt(self, signum: int, _: Any) -> None:
    """
    Stops the assistant from asking for more work on SIGUSR1
    """
    if signum != signal.SIGUSR1:
        return
    self._start_phasing_out()
def _start_phasing_out(self) -> None:
    """
    Go into a mode where we don't ask for more work and quit once existing
    tasks are done.

    Disabling keep_alive makes run()'s idle check break out of its loop;
    _stop_requesting_work makes _get_work() return immediately with no task.
    """
    self._config.keep_alive = False
    self._stop_requesting_work = True
def run(self) -> bool:
    """
    Returns True if all scheduled tasks were executed successfully.
    """
    logger.info('Running Worker with %d processes', self.worker_processes)
    sleeper = self._sleeper()
    self.run_succeeded = True
    self._add_worker()
    while True:
        # Chained comparison: block while the process pool is full
        # (and only when pooling is enabled, i.e. worker_processes > 0).
        while len(self._running_tasks) >= self.worker_processes > 0:
            logger.debug('%d running tasks, waiting for next task to finish', len(self._running_tasks))
            self._handle_next_task()
        get_work_response = self._get_work()
        if get_work_response.worker_state == WORKER_STATE_DISABLED:
            # Scheduler disabled this worker: stop asking for new work.
            self._start_phasing_out()
        if get_work_response.task_id is None:
            if not self._stop_requesting_work:
                self._log_remote_tasks(get_work_response)
            if len(self._running_tasks) == 0:
                # Fully idle: start (or keep) the idle timer used by _keep_alive.
                self._idle_since = self._idle_since or datetime.datetime.now()
                if self._keep_alive(get_work_response):
                    next(sleeper)
                    continue
                else:
                    break
            else:
                # No new work, but tasks are still running; wait for one to finish.
                self._handle_next_task()
                continue
        # task_id is not None:
        logger.debug('Pending tasks: %s', get_work_response.n_pending_tasks)
        self._run_task(get_work_response.task_id)
    # Drain: wait for all in-flight tasks before reporting the final result.
    while len(self._running_tasks):
        logger.debug('Shut down Worker, %d more tasks to go', len(self._running_tasks))
        self._handle_next_task()
    return self.run_succeeded
def _handle_rpc_message(self, message: dict[str, Any]) -> None:
    """Dispatch an RPC message from the scheduler to a callback method.

    The message is a dict ``{'name': ..., 'kwargs': ...}``; only attributes
    explicitly marked with ``is_rpc_message_callback`` may be invoked.
    """
    logger.info(f'Worker {self._id} got message {message}')
    name = message['name']
    kwargs = message['kwargs']
    # Look the target up on this worker and validate it before calling.
    func = getattr(self, name, None)
    if not callable(func):
        logger.error("Worker {} has no function '{}'".format(self._id, name))
        return
    if not getattr(func, 'is_rpc_message_callback', False):
        logger.error("Worker {} function '{}' is not available as rpc message callback".format(self._id, name))
        return
    logger.info("Worker {} successfully dispatched rpc message to function '{}'".format(self._id, name))
    func(**kwargs)
@luigi.worker.rpc_message_callback
def set_worker_processes(self, n: int) -> None:
    """RPC callback: resize this worker's process pool (minimum of one),
    then report the new size back to the scheduler."""
    self.worker_processes = n if n > 1 else 1
    self._scheduler.add_worker(self._id, {'workers': self.worker_processes})
@luigi.worker.rpc_message_callback
def dispatch_scheduler_message(self, task_id: str, message_id: str, content: str, **kwargs: Any) -> None:
    """RPC callback: forward a scheduler message to the matching running task.

    Silently ignored when the task is not running here or does not accept
    scheduler messages (no message queue attached).
    """
    task_id = str(task_id)
    task_process = self._running_tasks.get(task_id)
    if task_process is None:
        return
    queue = task_process.status_reporter.scheduler_messages
    if queue:
        queue.put(luigi.worker.SchedulerMessage(self._scheduler, task_id, message_id, content, **kwargs))
================================================
FILE: gokart/workspace_management.py
================================================
from __future__ import annotations
import itertools
import os
import pathlib
from logging import getLogger
from typing import Any
import gokart
from gokart.utils import flatten
logger = getLogger(__name__)
def _get_all_output_file_paths(task: gokart.TaskOnKart[Any]) -> list[str]:
    """Collect output paths of *task* and, recursively, of every required task."""
    paths: list[str] = [target.path() for target in flatten(task.output())]
    for child in flatten(task.requires()):
        paths.extend(_get_all_output_file_paths(child))
    return paths
def delete_local_unnecessary_outputs(task: gokart.TaskOnKart[Any]) -> None:
    """Delete workspace files that no task in *task*'s dependency tree outputs.

    Log files under ``<workspace>/log`` are always kept.
    """
    task.make_unique_id()  # this is required to make unique ids.
    workspace = pathlib.Path(task.workspace_directory)
    all_files = {str(p) for p in workspace.rglob('*.*')}
    log_files = {str(p) for p in pathlib.Path(os.path.join(task.workspace_directory, 'log')).rglob('*.*')}
    unnecessary_files = all_files - set(_get_all_output_file_paths(task)) - log_files
    if not unnecessary_files:
        logger.info('all files are necessary for this task.')
        return
    logger.info(f'remove following files: {os.linesep} {os.linesep.join(unnecessary_files)}')
    for file in unnecessary_files:
        os.remove(file)
================================================
FILE: gokart/zip_client.py
================================================
from __future__ import annotations
import os
import shutil
import zipfile
from abc import abstractmethod
from typing import IO
def _unzip_file(fp: str | IO[bytes] | os.PathLike[str], extract_dir: str) -> None:
zip_file = zipfile.ZipFile(fp)
zip_file.extractall(extract_dir)
zip_file.close()
class ZipClient:
    """Abstract interface for storing/retrieving a zip archive of a directory."""

    @abstractmethod
    def exists(self) -> bool:
        """Return True if the archive exists at ``path``."""
        pass

    @abstractmethod
    def make_archive(self) -> None:
        """Zip the temporary directory into the archive at ``path``."""
        pass

    @abstractmethod
    def unpack_archive(self) -> None:
        """Extract the archive at ``path`` into the temporary directory."""
        pass

    @abstractmethod
    def remove(self) -> None:
        """Delete the archive; a missing archive is not an error."""
        pass

    @property
    @abstractmethod
    def path(self) -> str:
        """Location of the archive."""
        pass


class LocalZipClient(ZipClient):
    """ZipClient backed by the local file system."""

    def __init__(self, file_path: str, temporary_directory: str) -> None:
        self._file_path = file_path
        self._temporary_directory = temporary_directory

    def exists(self) -> bool:
        return os.path.exists(self._file_path)

    def make_archive(self) -> None:
        # shutil.make_archive appends the format's extension itself,
        # so split it off the configured path first.
        [base, extension] = os.path.splitext(self._file_path)
        shutil.make_archive(base_name=base, format=extension[1:], root_dir=self._temporary_directory)

    def unpack_archive(self) -> None:
        _unzip_file(fp=self._file_path, extract_dir=self._temporary_directory)

    def remove(self) -> None:
        # Bug fix: the archive is a regular file, but shutil.rmtree only
        # removes directories; with ignore_errors=True the NotADirectoryError
        # was silently swallowed and the file was never deleted. Handle both
        # cases, keeping the original "missing path is not an error" behavior.
        if os.path.isdir(self._file_path):
            shutil.rmtree(self._file_path, ignore_errors=True)
        else:
            try:
                os.remove(self._file_path)
            except FileNotFoundError:
                pass

    @property
    def path(self) -> str:
        return self._file_path
================================================
FILE: gokart/zip_client_util.py
================================================
from __future__ import annotations
from gokart.object_storage import ObjectStorage
from gokart.zip_client import LocalZipClient, ZipClient
def make_zip_client(file_path: str, temporary_directory: str) -> ZipClient:
    """Return the ZipClient matching *file_path*'s scheme.

    Object-storage paths (e.g. s3/gs) get the storage-specific client;
    anything else falls back to the local file-system client.
    """
    if not ObjectStorage.if_object_storage_path(file_path):
        return LocalZipClient(file_path=file_path, temporary_directory=temporary_directory)
    return ObjectStorage.get_zip_client(file_path=file_path, temporary_directory=temporary_directory)
================================================
FILE: luigi.cfg
================================================
[core]
autoload_range: false
================================================
FILE: pyproject.toml
================================================
[build-system]
requires = ["hatchling", "uv-dynamic-versioning"]
build-backend = "hatchling.build"
[project]
name = "gokart"
description="Gokart solves reproducibility, task dependencies, constraints of good code, and ease of use for Machine Learning Pipeline. [Documentation](https://gokart.readthedocs.io/en/latest/)"
authors = [
{name = "M3, inc."}
]
license = "MIT"
readme = "README.md"
requires-python = ">=3.10, <4"
dependencies = [
"luigi>=3.8.0",
"boto3",
"slack-sdk",
"pandas",
"numpy",
"google-auth",
"pyarrow",
"google-api-python-client",
"APScheduler",
"redis",
"dill",
"backoff",
"typing-extensions>=4.11.0; python_version<'3.13'",
]
classifiers = [
"Development Status :: 5 - Production/Stable",
"License :: OSI Approved :: MIT License",
"Programming Language :: Python :: 3",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
"Programming Language :: Python :: 3.12",
"Programming Language :: Python :: 3.13",
"Programming Language :: Python :: 3.14",
]
dynamic = ["version"]
[project.optional-dependencies]
polars = ["polars>=0.19.0"]
[project.urls]
Homepage = "https://github.com/m3dev/gokart"
Repository = "https://github.com/m3dev/gokart"
Documentation = "https://gokart.readthedocs.io/en/latest/"
[dependency-groups]
test = [
"fakeredis",
"lupa",
"matplotlib",
"moto>=4.0", # for use mock_aws api
"mypy",
"polars>=0.19.0",
"pytest",
"pytest-cov",
"pytest-xdist",
"testfixtures",
"toml",
"types-redis",
"typing-extensions>=4.11.0",
]
lint = [
"ruff",
"mypy",
]
[tool.uv]
default-groups = ['test', 'lint']
cache-keys = [ { file = "pyproject.toml" }, { git = true } ]
[tool.hatch.version]
source = "uv-dynamic-versioning"
[tool.uv-dynamic-versioning]
enable = true
[tool.hatch.build.targets.sdist]
include = [
"/LICENSE",
"/README.md",
"/examples",
"/gokart",
"/test",
]
[tool.ruff]
line-length = 160
exclude = ["venv/*", "tox/*", "examples/*"]
[tool.ruff.lint]
# All the rules are listed on https://docs.astral.sh/ruff/rules/
extend-select = [
"B", # bugbear
"I", # isort
"UP", # pyupgrade, upgrade syntax for newer versions of the language.
]
# B006: Do not use mutable data structures for argument defaults. They are created during function definition time. All calls to the function reuse this one instance of that data structure, persisting changes between them.
# B008 Do not perform function calls in argument defaults. The call is performed only once at function definition time. All calls to your function will reuse the result of that definition-time function call. If this is intended, assign the function call to a module-level variable and use that variable as a default value.
ignore = ["B006", "B008"]
[tool.ruff.format]
quote-style = "single"
[tool.mypy]
ignore_missing_imports = true
check_untyped_defs = true
warn_unused_configs = true
warn_redundant_casts = true
no_implicit_optional = true
strict_optional = true
strict_equality = true
warn_unused_ignores = true
warn_return_any = true
disallow_incomplete_defs = true
disallow_any_generics = true
[tool.pytest.ini_options]
testpaths = ["test"]
addopts = "-n auto -s -v --durations=0"
================================================
FILE: test/__init__.py
================================================
================================================
FILE: test/config/__init__.py
================================================
from pathlib import Path
from typing import Final
# Directory holding the test configuration fixtures (this package's directory).
CONFIG_DIR: Final[Path] = Path(__file__).parent.resolve()
# Default pyproject.toml enabling the gokart.mypy plugin.
PYPROJECT_TOML: Final[Path] = CONFIG_DIR / 'pyproject.toml'
# Variant that additionally sets [tool.gokart-mypy] disallow_missing_parameters = true.
PYPROJECT_TOML_SET_DISALLOW_MISSING_PARAMETERS: Final[Path] = CONFIG_DIR / 'pyproject_disallow_missing_parameters.toml'
# INI config containing ${test_param} interpolation placeholders.
TEST_CONFIG_INI: Final[Path] = CONFIG_DIR / 'test_config.ini'
================================================
FILE: test/config/pyproject.toml
================================================
[tool.mypy]
plugins = ["gokart.mypy"]
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"]
================================================
FILE: test/config/pyproject_disallow_missing_parameters.toml
================================================
[tool.mypy]
plugins = ["gokart.mypy"]
[[tool.mypy.overrides]]
ignore_missing_imports = true
module = ["pandas.*", "apscheduler.*", "dill.*", "boto3.*", "testfixtures.*", "luigi.*"]
[tool.gokart-mypy]
disallow_missing_parameters = true
================================================
FILE: test/config/test_config.ini
================================================
[test_read_config._DummyTask]
param = ${test_param}
[test_build._DummyTask]
param = ${test_param}
================================================
FILE: test/conflict_prevention_lock/__init__.py
================================================
================================================
FILE: test/conflict_prevention_lock/test_task_lock.py
================================================
import random
import unittest
from typing import Any
from unittest.mock import patch
import gokart
from gokart.conflict_prevention_lock.task_lock import RedisClient, TaskLockParams, make_task_lock_key, make_task_lock_params, make_task_lock_params_for_run
class TestRedisClient(unittest.TestCase):
    """RedisClient should behave as a per-(host, port) singleton."""

    @staticmethod
    def _get_randint(host, port):
        # Stand-in for redis.Redis: yields a random value per construction,
        # so two *separate* constructions would produce different "clients".
        return random.randint(0, 100000)

    def test_redis_client_is_singleton(self):
        with patch('redis.Redis') as mock:
            mock.side_effect = self._get_randint
            redis_client_0_0 = RedisClient(host='host_0', port=123)
            redis_client_1 = RedisClient(host='host_1', port=123)
            redis_client_0_1 = RedisClient(host='host_0', port=123)
            # Different host -> distinct instance; same host/port -> shared
            # instance and shared underlying redis client.
            self.assertNotEqual(redis_client_0_0, redis_client_1)
            self.assertEqual(redis_client_0_0, redis_client_0_1)
            self.assertEqual(redis_client_0_0.get_redis_client(), redis_client_0_1.get_redis_client())
class TestMakeRedisKey(unittest.TestCase):
    """make_task_lock_key should combine the file stem with the unique id."""

    def test_make_redis_key(self):
        key = make_task_lock_key(file_path='gs://test_ll/dir/fname.pkl', unique_id='12345')
        self.assertEqual('fname_12345', key)
class TestMakeRedisParams(unittest.TestCase):
    """make_task_lock_params should enable locking only when a redis host is given."""

    def test_make_task_lock_params_with_valid_host(self):
        # With a host, locking is enabled (should_task_lock=True).
        result = make_task_lock_params(
            file_path='gs://aaa.pkl', unique_id='123', redis_host='0.0.0.0', redis_port=12345, redis_timeout=180, raise_task_lock_exception_on_collision=False
        )
        expected = TaskLockParams(
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_key='aaa_123',
            should_task_lock=True,
            redis_timeout=180,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=10,
        )
        self.assertEqual(result, expected)

    def test_make_task_lock_params_with_no_host(self):
        # Without a host, locking is disabled (should_task_lock=False).
        result = make_task_lock_params(
            file_path='gs://aaa.pkl', unique_id='123', redis_host=None, redis_port=12345, redis_timeout=180, raise_task_lock_exception_on_collision=False
        )
        expected = TaskLockParams(
            redis_host=None,
            redis_port=12345,
            redis_key='aaa_123',
            should_task_lock=False,
            redis_timeout=180,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=10,
        )
        self.assertEqual(result, expected)

    def test_assert_when_redis_timeout_is_too_short(self):
        # A timeout of 2s is rejected (too short relative to lock extension).
        with self.assertRaises(AssertionError):
            make_task_lock_params(
                file_path='test_dir/test_file.pkl',
                unique_id='123abc',
                redis_host='0.0.0.0',
                redis_port=12345,
                redis_timeout=2,
            )
class TestMakeTaskLockParamsForRun(unittest.TestCase):
    """make_task_lock_params_for_run should derive lock params from the task itself."""

    def test_make_task_lock_params_for_run(self):
        class _SampleDummyTask(gokart.TaskOnKart[Any]):
            pass

        task_self = _SampleDummyTask(
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=180,
        )
        result = make_task_lock_params_for_run(task_self=task_self, lock_extend_seconds=10)
        expected = TaskLockParams(
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=180,
            # Key is '<task name>_<unique id>-run'; the hash literal pins the
            # task's unique-id computation — changing it would break this test.
            redis_key='_SampleDummyTask_7e857f231830ca0fd6cf829d99f43961-run',
            should_task_lock=True,
            raise_task_lock_exception_on_collision=True,
            lock_extend_seconds=10,
        )
        self.assertEqual(result, expected)
================================================
FILE: test/conflict_prevention_lock/test_task_lock_wrappers.py
================================================
import time
import unittest
from unittest.mock import MagicMock, patch
import fakeredis
from gokart.conflict_prevention_lock.task_lock import make_task_lock_params
from gokart.conflict_prevention_lock.task_lock_wrappers import wrap_dump_with_lock, wrap_load_with_lock, wrap_remove_with_lock
def _sample_func_with_error(a: int, b: str) -> None:
    """Test helper that always raises, to exercise lock release on failure."""
    raise Exception()
def _sample_long_func(a: int, b: str) -> dict[str, int | str]:
    """Test helper that outlives the lock timeout.

    Sleeps 2.7s — longer than the 2s redis_timeout used by the
    'check_lock_extended' tests (with lock_extend_seconds=1), so the call
    only succeeds if the lock is being extended while the function runs.
    """
    time.sleep(2.7)
    return dict(a=a, b=b)
class TestWrapDumpWithLock(unittest.TestCase):
    """wrap_dump_with_lock: the wrapped dump runs under a redis lock (when
    configured), is skipped when the cache already exists, and releases the
    lock on both success and failure."""

    def test_no_redis(self):
        # No redis host: the wrapper is a pass-through to func.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host=None,
            redis_port=None,
        )
        mock_func = MagicMock()
        wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
        mock_func.assert_called_once()
        called_args, called_kwargs = mock_func.call_args
        self.assertTupleEqual(called_args, (123,))
        self.assertDictEqual(called_kwargs, dict(b='abc'))

    def test_use_redis(self):
        # With redis (faked): args/kwargs still reach func unchanged.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            mock_func = MagicMock()
            wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))

    def test_if_func_is_skipped_when_cache_already_exists(self):
        # exist_check=True means the output is already cached: func must not run.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            mock_func = MagicMock()
            wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: True)(123, b='abc')
            mock_func.assert_not_called()

    def test_check_lock_extended(self):
        # _sample_long_func (2.7s) outlives redis_timeout=2; succeeding proves
        # the lock is extended every lock_extend_seconds=1 while func runs.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=2,
            lock_extend_seconds=1,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            wrap_dump_with_lock(func=_sample_long_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')

    def test_lock_is_removed_after_func_is_finished(self):
        # A shared FakeServer lets us inspect the key after the call returns.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            mock_func = MagicMock()
            wrap_dump_with_lock(func=mock_func, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))
        fake_redis = fakeredis.FakeStrictRedis(server=server)
        # KeyError proves the lock key was deleted after the dump finished.
        with self.assertRaises(KeyError):
            fake_redis[task_lock_params.redis_key]

    def test_lock_is_removed_after_func_is_finished_with_error(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            # NOTE(review): if the wrapped call ever stopped raising, the
            # assertions below would be skipped silently; assertRaises around
            # the call would be stricter.
            try:
                wrap_dump_with_lock(func=_sample_func_with_error, task_lock_params=task_lock_params, exist_check=lambda: False)(123, b='abc')
            except Exception:
                fake_redis = fakeredis.FakeStrictRedis(server=server)
                with self.assertRaises(KeyError):
                    fake_redis[task_lock_params.redis_key]
class TestWrapLoadWithLock(unittest.TestCase):
    """wrap_load_with_lock: like the dump wrapper, but the wrapped function's
    return value is propagated and there is no exist_check short-circuit."""

    def test_no_redis(self):
        # No redis host: pass-through, and func's return value is forwarded.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host=None,
            redis_port=None,
        )
        mock_func = MagicMock()
        resulted = wrap_load_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc')
        mock_func.assert_called_once()
        called_args, called_kwargs = mock_func.call_args
        self.assertTupleEqual(called_args, (123,))
        self.assertDictEqual(called_kwargs, dict(b='abc'))
        self.assertEqual(resulted, mock_func())

    def test_use_redis(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            mock_func = MagicMock()
            resulted = wrap_load_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))
            self.assertEqual(resulted, mock_func())

    def test_check_lock_extended(self):
        # _sample_long_func (2.7s) outlives redis_timeout=2; success requires
        # the lock to be extended (every 1s) while loading.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=2,
            lock_extend_seconds=1,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            resulted = wrap_load_with_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc')
            expected = dict(a=123, b='abc')
            self.assertEqual(resulted, expected)

    def test_lock_is_removed_after_func_is_finished(self):
        # Shared FakeServer allows inspecting the lock key after the call.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            mock_func = MagicMock()
            resulted = wrap_load_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))
            self.assertEqual(resulted, mock_func())
        fake_redis = fakeredis.FakeStrictRedis(server=server)
        # KeyError proves the lock key was released.
        with self.assertRaises(KeyError):
            fake_redis[task_lock_params.redis_key]

    def test_lock_is_removed_after_func_is_finished_with_error(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            # NOTE(review): if the wrapped call ever stopped raising, these
            # assertions would be skipped silently; assertRaises would be stricter.
            try:
                wrap_load_with_lock(func=_sample_func_with_error, task_lock_params=task_lock_params)(123, b='abc')
            except Exception:
                fake_redis = fakeredis.FakeStrictRedis(server=server)
                with self.assertRaises(KeyError):
                    fake_redis[task_lock_params.redis_key]
class TestWrapRemoveWithLock(unittest.TestCase):
    """wrap_remove_with_lock: mirrors the load-wrapper contract (return value
    propagated, lock held during the call and released afterwards)."""

    def test_no_redis(self):
        # No redis host: pass-through, and func's return value is forwarded.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host=None,
            redis_port=None,
        )
        mock_func = MagicMock()
        resulted = wrap_remove_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc')
        mock_func.assert_called_once()
        called_args, called_kwargs = mock_func.call_args
        self.assertTupleEqual(called_args, (123,))
        self.assertDictEqual(called_kwargs, dict(b='abc'))
        self.assertEqual(resulted, mock_func())

    def test_use_redis(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            mock_func = MagicMock()
            resulted = wrap_remove_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))
            self.assertEqual(resulted, mock_func())

    def test_check_lock_extended(self):
        # _sample_long_func (2.7s) outlives redis_timeout=2; success requires
        # the lock to be extended (every 1s) during removal.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
            redis_timeout=2,
            lock_extend_seconds=1,
        )
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.side_effect = fakeredis.FakeRedis
            resulted = wrap_remove_with_lock(func=_sample_long_func, task_lock_params=task_lock_params)(123, b='abc')
            expected = dict(a=123, b='abc')
            self.assertEqual(resulted, expected)

    def test_lock_is_removed_after_func_is_finished(self):
        # Shared FakeServer allows inspecting the lock key after the call.
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            mock_func = MagicMock()
            resulted = wrap_remove_with_lock(func=mock_func, task_lock_params=task_lock_params)(123, b='abc')
            mock_func.assert_called_once()
            called_args, called_kwargs = mock_func.call_args
            self.assertTupleEqual(called_args, (123,))
            self.assertDictEqual(called_kwargs, dict(b='abc'))
            self.assertEqual(resulted, mock_func())
        fake_redis = fakeredis.FakeStrictRedis(server=server)
        # KeyError proves the lock key was released.
        with self.assertRaises(KeyError):
            fake_redis[task_lock_params.redis_key]

    def test_lock_is_removed_after_func_is_finished_with_error(self):
        task_lock_params = make_task_lock_params(
            file_path='test_dir/test_file.pkl',
            unique_id='123abc',
            redis_host='0.0.0.0',
            redis_port=12345,
        )
        server = fakeredis.FakeServer()
        with patch('gokart.conflict_prevention_lock.task_lock.redis.Redis') as redis_mock:
            redis_mock.return_value = fakeredis.FakeRedis(server=server, host=task_lock_params.redis_host, port=task_lock_params.redis_port)
            # NOTE(review): if the wrapped call ever stopped raising, these
            # assertions would be skipped silently; assertRaises would be stricter.
            try:
                wrap_remove_with_lock(func=_sample_func_with_error, task_lock_params=task_lock_params)(123, b='abc')
            except Exception:
                fake_redis = fakeredis.FakeStrictRedis(server=server)
                with self.assertRaises(KeyError):
                    fake_redis[task_lock_params.redis_key]
================================================
FILE: test/file_processor/__init__.py
================================================
================================================
FILE: test/file_processor/test_base.py
================================================
"""Tests for base file processors (non-DataFrame processors)."""
from __future__ import annotations
import os
import tempfile
import unittest
from collections.abc import Callable
import boto3
from luigi import LocalTarget
from moto import mock_aws
from gokart.file_processor import PickleFileProcessor
from gokart.object_storage import ObjectStorage
class TestPickleFileProcessor(unittest.TestCase):
    """Round-trip tests for PickleFileProcessor over local and S3-backed targets."""

    def test_dump_and_load_normal_obj(self):
        """A plain string survives a dump/load round-trip unchanged."""
        payload = 'abc'
        proc = PickleFileProcessor()
        with tempfile.TemporaryDirectory() as tmp_dir:
            target = LocalTarget(path=f'{tmp_dir}/temp.pkl', format=proc.format())
            with target.open('w') as sink:
                proc.dump(payload, sink)
            with target.open('r') as src:
                restored = proc.load(src)
        self.assertEqual(restored, payload)

    def test_dump_and_load_class(self):
        """An object whose bound method was replaced by a decorator round-trips intact."""
        import functools

        def plus1(func: Callable[..., int]) -> Callable[..., int]:
            @functools.wraps(func)
            def wrapped() -> int:
                return func() + 1

            return wrapped

        class A:
            def __init__(self) -> None:
                self.run = plus1(self.run)  # type: ignore

            def run(self) -> int:  # type: ignore
                return 1

        instance = A()
        proc = PickleFileProcessor()
        with tempfile.TemporaryDirectory() as tmp_dir:
            target = LocalTarget(path=f'{tmp_dir}/temp.pkl', format=proc.format())
            with target.open('w') as sink:
                proc.dump(instance, sink)
            with target.open('r') as src:
                restored = proc.load(src)
        self.assertEqual(restored.run(), instance.run())

    @mock_aws
    def test_dump_and_load_with_readables3file(self):
        """The same round-trip works through an S3 object-storage target."""
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        file_path = os.path.join('s3://test/', 'test.pkl')
        payload = 'abc'
        proc = PickleFileProcessor()
        target = ObjectStorage.get_object_storage_target(file_path, proc.format())
        with target.open('w') as sink:
            proc.dump(payload, sink)
        with target.open('r') as src:
            restored = proc.load(src)
        self.assertEqual(restored, payload)
================================================
FILE: test/file_processor/test_factory.py
================================================
"""Tests for file processor factory function."""
from __future__ import annotations
import unittest
from gokart.file_processor import (
CsvFileProcessor,
FeatherFileProcessor,
GzipFileProcessor,
JsonFileProcessor,
NpzFileProcessor,
ParquetFileProcessor,
TextFileProcessor,
make_file_processor,
)
class TestMakeFileProcessor(unittest.TestCase):
    """make_file_processor should select the processor matching the file extension."""

    def _assert_processor(self, file_name, expected_cls, store_index_in_feather=False):
        # Shared helper: build a processor for *file_name* and check its concrete type.
        processor = make_file_processor(file_name, store_index_in_feather=store_index_in_feather)
        self.assertIsInstance(processor, expected_cls)

    def test_make_file_processor_with_txt_extension(self):
        self._assert_processor('test.txt', TextFileProcessor)

    def test_make_file_processor_with_csv_extension(self):
        self._assert_processor('test.csv', CsvFileProcessor)

    def test_make_file_processor_with_gz_extension(self):
        self._assert_processor('test.gz', GzipFileProcessor)

    def test_make_file_processor_with_json_extension(self):
        self._assert_processor('test.json', JsonFileProcessor)

    def test_make_file_processor_with_ndjson_extension(self):
        self._assert_processor('test.ndjson', JsonFileProcessor)

    def test_make_file_processor_with_npz_extension(self):
        self._assert_processor('test.npz', NpzFileProcessor)

    def test_make_file_processor_with_parquet_extension(self):
        self._assert_processor('test.parquet', ParquetFileProcessor)

    def test_make_file_processor_with_feather_extension(self):
        self._assert_processor('test.feather', FeatherFileProcessor, store_index_in_feather=True)

    def test_make_file_processor_with_unsupported_extension(self):
        # Unknown extensions are rejected with an AssertionError.
        with self.assertRaises(AssertionError):
            make_file_processor('test.unsupported', store_index_in_feather=False)
================================================
FILE: test/file_processor/test_pandas.py
================================================
"""Tests for pandas-specific file processors."""
from __future__ import annotations
import tempfile
import unittest
import pandas as pd
import pytest
from luigi import LocalTarget
from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor
class TestCsvFileProcessor(unittest.TestCase):
    """CsvFileProcessor must honour the configured encoding on both dump and load."""

    def test_dump_csv_with_utf8(self):
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        proc = CsvFileProcessor()
        with tempfile.TemporaryDirectory() as tmp:
            csv_path = f'{tmp}/temp.csv'
            target = LocalTarget(path=csv_path, format=proc.format())
            with target.open('w') as sink:
                proc.dump(frame, sink)
            # Reading back with utf-8 proves the dump was written as utf-8.
            reloaded = pd.read_csv(csv_path, encoding='utf-8')
            pd.testing.assert_frame_equal(frame, reloaded)

    def test_dump_csv_with_cp932(self):
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        proc = CsvFileProcessor(encoding='cp932')
        with tempfile.TemporaryDirectory() as tmp:
            csv_path = f'{tmp}/temp.csv'
            target = LocalTarget(path=csv_path, format=proc.format())
            with target.open('w') as sink:
                proc.dump(frame, sink)
            # Reading back with cp932 proves the dump was written as cp932.
            reloaded = pd.read_csv(csv_path, encoding='cp932')
            pd.testing.assert_frame_equal(frame, reloaded)

    def test_load_csv_with_utf8(self):
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        proc = CsvFileProcessor()
        with tempfile.TemporaryDirectory() as tmp:
            csv_path = f'{tmp}/temp.csv'
            # Write the fixture file with pandas, then load through the processor.
            frame.to_csv(csv_path, encoding='utf-8', index=False)
            target = LocalTarget(path=csv_path, format=proc.format())
            with target.open('r') as src:
                reloaded = proc.load(src)
            pd.testing.assert_frame_equal(frame, reloaded)

    def test_load_csv_with_cp932(self):
        frame = pd.DataFrame({'あ': [1, 2, 3], 'い': [4, 5, 6]})
        proc = CsvFileProcessor(encoding='cp932')
        with tempfile.TemporaryDirectory() as tmp:
            csv_path = f'{tmp}/temp.csv'
            # Write the fixture file with pandas, then load through the processor.
            frame.to_csv(csv_path, encoding='cp932', index=False)
            target = LocalTarget(path=csv_path, format=proc.format())
            with target.open('r') as src:
                reloaded = proc.load(src)
            pd.testing.assert_frame_equal(frame, reloaded)
class TestJsonFileProcessor:
    """Round-trip tests for JsonFileProcessor over DataFrame and dict inputs.

    The parametrization shows that ``orient=None`` serializes to a single JSON
    object while ``orient='records'`` serializes to newline-delimited JSON
    (one object per row).
    """

    @pytest.mark.parametrize(
        'orient,input_data,expected_json',
        [
            pytest.param(
                None,
                pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
                '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}',
                id='With Default Orient for DataFrame',
            ),
            pytest.param(
                'records',
                pd.DataFrame({'A': [1, 2, 3], 'B': [4, 5, 6]}),
                '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n',
                id='With Records Orient for DataFrame',
            ),
            pytest.param(None, {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":{"0":1,"1":2,"2":3},"B":{"0":4,"1":5,"2":6}}', id='With Default Orient for Dict'),
            pytest.param('records', {'A': [1, 2, 3], 'B': [4, 5, 6]}, '{"A":1,"B":4}\n{"A":2,"B":5}\n{"A":3,"B":6}\n', id='With Records Orient for Dict'),
            pytest.param(None, {}, '{}', id='With Default Orient for Empty Dict'),
            pytest.param('records', {}, '\n', id='With Records Orient for Empty Dict'),
        ],
    )
    def test_dump_and_load_json(self, orient, input_data, expected_json):
        """Dump ``input_data`` and verify both the raw JSON text and the reloaded frame."""
        processor = JsonFileProcessor(orient=orient)
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.json'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(input_data, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
                # Rewind and read the raw bytes to check the exact serialized form.
                f.seek(0)
                loaded_json = f.read().decode('utf-8')
                assert loaded_json == expected_json
            # The reloaded data must equal the input when coerced to a DataFrame.
            df_input = pd.DataFrame(input_data)
            pd.testing.assert_frame_equal(df_input, loaded_df)
class TestFeatherFileProcessor(unittest.TestCase):
    """FeatherFileProcessor round-trip behaviour, including index handling."""

    def _roundtrip(self, frame):
        # Dump *frame* to a temporary feather file and load it back.
        proc = FeatherFileProcessor(store_index_in_feather=True)
        with tempfile.TemporaryDirectory() as tmp:
            target = LocalTarget(path=f'{tmp}/temp.feather', format=proc.format())
            with target.open('w') as sink:
                proc.dump(frame, sink)
            with target.open('r') as src:
                return proc.load(src)

    def test_feather_should_return_same_dataframe(self):
        frame = pd.DataFrame({'a': [1]})
        pd.testing.assert_frame_equal(frame, self._roundtrip(frame))

    def test_feather_should_save_index_name(self):
        frame = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='index_name'))
        pd.testing.assert_frame_equal(frame, self._roundtrip(frame))

    def test_feather_should_raise_error_index_name_is_None(self):
        # An index literally named 'None' (the string) is expected to be rejected
        # with an AssertionError at dump time — presumably because the processor
        # reserves that string internally; verify against FeatherFileProcessor.
        frame = pd.DataFrame({'a': [1]}, index=pd.Index([1], name='None'))
        proc = FeatherFileProcessor(store_index_in_feather=True)
        with tempfile.TemporaryDirectory() as tmp:
            target = LocalTarget(path=f'{tmp}/temp.feather', format=proc.format())
            with target.open('w') as sink:
                with self.assertRaises(AssertionError):
                    proc.dump(frame, sink)
================================================
FILE: test/file_processor/test_polars.py
================================================
"""Tests for polars-specific file processors."""
from __future__ import annotations
import tempfile
from typing import TYPE_CHECKING
import pandas as pd
import pytest
from luigi import LocalTarget
from gokart.file_processor import CsvFileProcessor, FeatherFileProcessor, JsonFileProcessor, ParquetFileProcessor
if TYPE_CHECKING:
import polars as pl
# polars is an optional dependency: the polars-specific test classes below are
# skipped via the HAS_POLARS flag when it is not importable.
try:
    import polars as pl

    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestCsvFileProcessorWithPolars:
    """Tests for CsvFileProcessor with polars support."""

    def test_dump_polars_dataframe(self):
        """Test dumping a polars DataFrame"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = CsvFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.csv'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            # Verify file was created and can be read by polars
            loaded_df = pl.read_csv(temp_path)
            assert loaded_df.equals(df)

    def test_load_polars_dataframe(self):
        """Test loading a CSV as polars DataFrame"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = CsvFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.csv'
            df.write_csv(temp_path)
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_and_load_polars_roundtrip(self):
        """Test roundtrip dump and load with polars"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = CsvFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.csv'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_polars_with_pandas_load(self):
        """Test that polars dump can be loaded by pandas processor"""
        df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor_polars = CsvFileProcessor(dataframe_type='polars')
        processor_pandas = CsvFileProcessor(dataframe_type='pandas')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.csv'
            # Dump with polars
            local_target = LocalTarget(path=temp_path, format=processor_polars.format())
            with local_target.open('w') as f:
                processor_polars.dump(df_polars, f)
            # Load with pandas
            with local_target.open('r') as f:
                loaded_df = processor_pandas.load(f)
            assert isinstance(loaded_df, pd.DataFrame)
            # Bug fix: the equality check's result was previously discarded
            # (bare ``df_polars.equals(...)`` expression); assert it so a
            # value mismatch actually fails the test.
            assert df_polars.equals(pl.from_pandas(loaded_df))

    def test_polars_with_different_separator(self):
        """Test polars with TSV (tab-separated values)"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = CsvFileProcessor(sep='\t', dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.tsv'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_error_when_polars_not_available_for_load(self):
        """Test error message when polars is requested but a polars operation fails"""
        # This test is a bit tricky since polars IS installed in this test class
        # We'll just verify the processor accepts the parameter
        processor = CsvFileProcessor(dataframe_type='polars')
        assert processor._dataframe_type == 'polars'
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestJsonFileProcessorWithPolars:
    """Tests for JsonFileProcessor with polars support."""

    def test_dump_polars_dataframe(self):
        """Dumping a polars DataFrame produces JSON a native polars reader understands."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        proc = JsonFileProcessor(orient=None, dataframe_type='polars')
        with tempfile.TemporaryDirectory() as tmp:
            json_path = f'{tmp}/temp.json'
            target = LocalTarget(path=json_path, format=proc.format())
            with target.open('w') as sink:
                proc.dump(frame, sink)
            assert pl.read_json(json_path).equals(frame)

    def test_load_polars_dataframe(self):
        """Loading a JSON file written by polars yields an equal polars DataFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        proc = JsonFileProcessor(orient=None, dataframe_type='polars')
        with tempfile.TemporaryDirectory() as tmp:
            json_path = f'{tmp}/temp.json'
            frame.write_json(json_path)
            target = LocalTarget(path=json_path, format=proc.format())
            with target.open('r') as src:
                loaded = proc.load(src)
            assert isinstance(loaded, pl.DataFrame)
            assert loaded.equals(frame)

    def test_dump_and_load_polars_roundtrip(self):
        """Dump followed by load returns an equal polars DataFrame."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        proc = JsonFileProcessor(orient=None, dataframe_type='polars')
        with tempfile.TemporaryDirectory() as tmp:
            target = LocalTarget(path=f'{tmp}/temp.json', format=proc.format())
            with target.open('w') as sink:
                proc.dump(frame, sink)
            with target.open('r') as src:
                loaded = proc.load(src)
            assert isinstance(loaded, pl.DataFrame)
            assert loaded.equals(frame)

    def test_dump_and_load_ndjson_with_polars(self):
        """The records orient (newline-delimited JSON) also round-trips."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        proc = JsonFileProcessor(orient='records', dataframe_type='polars')
        with tempfile.TemporaryDirectory() as tmp:
            target = LocalTarget(path=f'{tmp}/temp.ndjson', format=proc.format())
            with target.open('w') as sink:
                proc.dump(frame, sink)
            with target.open('r') as src:
                loaded = proc.load(src)
            assert isinstance(loaded, pl.DataFrame)
            assert loaded.equals(frame)

    def test_dump_polars_with_pandas_load(self):
        """JSON dumped by the polars processor is readable by the pandas processor."""
        frame = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        polars_proc = JsonFileProcessor(orient=None, dataframe_type='polars')
        pandas_proc = JsonFileProcessor(orient=None, dataframe_type='pandas')
        with tempfile.TemporaryDirectory() as tmp:
            target = LocalTarget(path=f'{tmp}/temp.json', format=polars_proc.format())
            with target.open('w') as sink:
                polars_proc.dump(frame, sink)
            with target.open('r') as src:
                loaded = pandas_proc.load(src)
            assert isinstance(loaded, pd.DataFrame)
            # Compare column contents value by value.
            assert list(loaded['a']) == [1, 2, 3]
            assert list(loaded['b']) == [4, 5, 6]
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestParquetFileProcessorWithPolars:
    """Tests for ParquetFileProcessor with polars support."""

    def test_dump_polars_dataframe(self):
        """Test dumping a polars DataFrame to Parquet"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            # Verify file was created and can be read by polars
            loaded_df = pl.read_parquet(temp_path)
            assert loaded_df.equals(df)

    def test_load_polars_dataframe(self):
        """Test loading a Parquet as polars DataFrame"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            df.write_parquet(temp_path)
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_and_load_polars_roundtrip(self):
        """Test roundtrip dump and load with polars"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_polars_with_pandas_load(self):
        """Test that polars dump can be loaded by pandas processor"""
        df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor_polars = ParquetFileProcessor(dataframe_type='polars')
        processor_pandas = ParquetFileProcessor(dataframe_type='pandas')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            # Dump with polars
            local_target = LocalTarget(path=temp_path, format=processor_polars.format())
            with local_target.open('w') as f:
                processor_polars.dump(df_polars, f)
            # Load with pandas
            with local_target.open('r') as f:
                loaded_df = processor_pandas.load(f)
            assert isinstance(loaded_df, pd.DataFrame)
            # Bug fix: the equality check's result was previously discarded
            # (bare ``df_polars.equals(...)`` expression); assert it so a
            # value mismatch actually fails the test.
            assert df_polars.equals(pl.from_pandas(loaded_df))

    def test_parquet_with_compression(self):
        """Test polars with parquet compression"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = ParquetFileProcessor(compression='gzip', dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.parquet'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestFeatherFileProcessorWithPolars:
    """Tests for FeatherFileProcessor with polars support."""

    def test_dump_polars_dataframe(self):
        """Test dumping a polars DataFrame to Feather"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.feather'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            # Verify file was created and can be read by polars
            loaded_df = pl.read_ipc(temp_path)
            assert loaded_df.equals(df)

    def test_load_polars_dataframe(self):
        """Test loading a Feather as polars DataFrame"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.feather'
            df.write_ipc(temp_path)
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_and_load_polars_roundtrip(self):
        """Test roundtrip dump and load with polars"""
        df = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.feather'
            local_target = LocalTarget(path=temp_path, format=processor.format())
            with local_target.open('w') as f:
                processor.dump(df, f)
            with local_target.open('r') as f:
                loaded_df = processor.load(f)
            assert isinstance(loaded_df, pl.DataFrame)
            assert loaded_df.equals(df)

    def test_dump_polars_with_pandas_load(self):
        """Test that polars dump can be loaded by pandas processor"""
        df_polars = pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})
        processor_polars = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars')
        processor_pandas = FeatherFileProcessor(store_index_in_feather=False, dataframe_type='pandas')
        with tempfile.TemporaryDirectory() as temp_dir:
            temp_path = f'{temp_dir}/temp.feather'
            # Dump with polars
            local_target = LocalTarget(path=temp_path, format=processor_polars.format())
            with local_target.open('w') as f:
                processor_polars.dump(df_polars, f)
            # Load with pandas
            with local_target.open('r') as f:
                loaded_df = processor_pandas.load(f)
            assert isinstance(loaded_df, pd.DataFrame)
            # Bug fix: the equality check's result was previously discarded
            # (bare ``df_polars.equals(...)`` expression); assert it so a
            # value mismatch actually fails the test.
            assert df_polars.equals(pl.from_pandas(loaded_df))
@pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
class TestLazyFrameSupport:
    """Tests for LazyFrame support in file processors using dataframe_type='polars-lazy'"""

    @staticmethod
    def _sample_frame():
        # Common fixture data shared by every test in this class.
        return pl.DataFrame({'a': [1, 2, 3], 'b': [4, 5, 6]})

    def _assert_loads_lazy(self, processor, file_name, write):
        # Write a file with a native polars writer, then check the processor
        # loads it back as a LazyFrame with identical contents.
        frame = self._sample_frame()
        with tempfile.TemporaryDirectory() as tmp:
            path = f'{tmp}/{file_name}'
            write(frame, path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as f:
                loaded = processor.load(f)
            assert isinstance(loaded, pl.LazyFrame)
            # Collect while the temp file still exists: a LazyFrame reads lazily.
            assert loaded.collect().equals(frame)

    def _assert_dumps_lazy(self, processor, file_name, read):
        # Dump a LazyFrame through the processor, then check a native polars
        # reader sees the same data.
        lazy = self._sample_frame().lazy()
        with tempfile.TemporaryDirectory() as tmp:
            path = f'{tmp}/{file_name}'
            target = LocalTarget(path=path, format=processor.format())
            with target.open('w') as f:
                processor.dump(lazy, f)
            assert read(path).equals(lazy.collect())

    def test_csv_load_lazy(self):
        """Test loading CSV as LazyFrame"""
        self._assert_loads_lazy(CsvFileProcessor(dataframe_type='polars-lazy'), 'temp.csv', lambda frame, path: frame.write_csv(path))

    def test_csv_dump_lazyframe(self):
        """Test dumping a LazyFrame to CSV"""
        self._assert_dumps_lazy(CsvFileProcessor(dataframe_type='polars-lazy'), 'temp.csv', pl.read_csv)

    def test_parquet_load_lazy(self):
        """Test loading Parquet as LazyFrame"""
        self._assert_loads_lazy(ParquetFileProcessor(dataframe_type='polars-lazy'), 'temp.parquet', lambda frame, path: frame.write_parquet(path))

    def test_parquet_dump_lazyframe(self):
        """Test dumping a LazyFrame to Parquet"""
        self._assert_dumps_lazy(ParquetFileProcessor(dataframe_type='polars-lazy'), 'temp.parquet', pl.read_parquet)

    def test_feather_load_lazy(self):
        """Test loading Feather as LazyFrame"""
        self._assert_loads_lazy(
            FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars-lazy'), 'temp.feather', lambda frame, path: frame.write_ipc(path)
        )

    def test_feather_dump_lazyframe(self):
        """Test dumping a LazyFrame to Feather"""
        self._assert_dumps_lazy(FeatherFileProcessor(store_index_in_feather=False, dataframe_type='polars-lazy'), 'temp.feather', pl.read_ipc)

    def test_json_load_lazy_ndjson(self):
        """Test loading NDJSON as LazyFrame"""
        self._assert_loads_lazy(
            JsonFileProcessor(orient='records', dataframe_type='polars-lazy'), 'temp.ndjson', lambda frame, path: frame.write_ndjson(path)
        )

    def test_json_dump_lazyframe_ndjson(self):
        """Test dumping a LazyFrame to NDJSON"""
        self._assert_dumps_lazy(JsonFileProcessor(orient='records', dataframe_type='polars-lazy'), 'temp.ndjson', pl.read_ndjson)

    def test_json_load_lazy_standard(self):
        """Test loading standard JSON (orient=None) as LazyFrame"""
        self._assert_loads_lazy(JsonFileProcessor(orient=None, dataframe_type='polars-lazy'), 'temp.json', lambda frame, path: frame.write_json(path))

    def test_json_dump_lazyframe_standard(self):
        """Test dumping a LazyFrame to standard JSON (orient=None)"""
        self._assert_dumps_lazy(JsonFileProcessor(orient=None, dataframe_type='polars-lazy'), 'temp.json', pl.read_json)

    def test_polars_returns_dataframe(self):
        """Test that dataframe_type='polars' returns DataFrame (not LazyFrame)"""
        frame = self._sample_frame()
        processor = ParquetFileProcessor(dataframe_type='polars')
        with tempfile.TemporaryDirectory() as tmp:
            path = f'{tmp}/temp.parquet'
            frame.write_parquet(path)
            target = LocalTarget(path=path, format=processor.format())
            with target.open('r') as f:
                loaded = processor.load(f)
            assert isinstance(loaded, pl.DataFrame)
            assert loaded.equals(frame)
================================================
FILE: test/in_memory/test_in_memory_target.py
================================================
from datetime import datetime
from time import sleep
import pytest
from gokart.conflict_prevention_lock.task_lock import TaskLockParams
from gokart.in_memory import InMemoryCacheRepository, InMemoryTarget, make_in_memory_target
class TestInMemoryTarget:
    """Behavioral tests for InMemoryTarget backed by the in-memory cache repository."""

    @pytest.fixture
    def task_lock_params(self) -> TaskLockParams:
        # Locking is irrelevant for these tests, so build params with locking disabled.
        return TaskLockParams(
            redis_host=None,
            redis_port=None,
            redis_timeout=None,
            redis_key='dummy',
            should_task_lock=False,
            raise_task_lock_exception_on_collision=False,
            lock_extend_seconds=0,
        )

    @pytest.fixture
    def target(self, task_lock_params: TaskLockParams) -> InMemoryTarget:
        return make_in_memory_target(target_key='dummy_key', task_lock_params=task_lock_params)

    @pytest.fixture(autouse=True)
    def clear_repo(self) -> None:
        # The cache repository is shared process-wide; start each test from a clean slate.
        InMemoryCacheRepository().clear()

    def test_dump_and_load_data(self, target: InMemoryTarget) -> None:
        payload = 'dummy_data'
        target.dump(payload)
        assert target.load() == payload

    def test_exist(self, target: InMemoryTarget) -> None:
        assert not target.exists()
        target.dump('dummy_data')
        assert target.exists()

    def test_last_modified_time(self, target: InMemoryTarget) -> None:
        # Renamed the original local `input`, which shadowed the builtin.
        target.dump('dummy_data')
        first_stamp = target.last_modification_time()
        assert isinstance(first_stamp, datetime)
        sleep(0.1)
        target.dump('another_data')
        assert first_stamp < target.last_modification_time()
        # After removal, asking for the modification time raises ValueError.
        target.remove()
        with pytest.raises(ValueError):
            assert target.last_modification_time()
================================================
FILE: test/in_memory/test_repository.py
================================================
import time
import pytest
from gokart.in_memory import InMemoryCacheRepository as Repo
dummy_num = 100
class TestInMemoryCacheRepository:
    """Tests for the shared InMemoryCacheRepository key/value store."""

    @pytest.fixture
    def repo(self) -> Repo:
        # Fresh, empty repository for every test (the store is shared process-wide).
        repo = Repo()
        repo.clear()
        return repo

    def test_set(self, repo: Repo) -> None:
        repo.set_value('dummy_key', dummy_num)
        assert repo.size == 1
        for key, value in repo.get_gen():
            assert (key, value) == ('dummy_key', dummy_num)
        repo.set_value('another_key', 'another_value')
        assert repo.size == 2

    def test_get(self, repo: Repo) -> None:
        repo.set_value('dummy_key', dummy_num)
        repo.set_value('another_key', 'another_value')
        # Raise Error when key doesn't exist.
        # (Fix: this was a stray string literal mid-function — a no-op
        # statement, not a docstring — now a proper comment.)
        with pytest.raises(KeyError):
            repo.get_value('not_exist_key')
        assert repo.get_value('dummy_key') == dummy_num
        assert repo.get_value('another_key') == 'another_value'

    def test_empty(self, repo: Repo) -> None:
        assert repo.empty()
        repo.set_value('dummy_key', dummy_num)  # fix: was misspelled 'dummmy_key'
        assert not repo.empty()

    def test_has(self, repo: Repo) -> None:
        assert not repo.has('dummy_key')
        repo.set_value('dummy_key', dummy_num)
        assert repo.has('dummy_key')
        assert not repo.has('not_exist_key')

    def test_remove(self, repo: Repo) -> None:
        repo.set_value('dummy_key', dummy_num)
        # Removing a missing key is an AssertionError by contract.
        with pytest.raises(AssertionError):
            repo.remove('not_exist_key')
        repo.remove('dummy_key')
        assert not repo.has('dummy_key')

    def test_last_modification_time(self, repo: Repo) -> None:
        repo.set_value('dummy_key', dummy_num)
        date1 = repo.get_last_modification_time('dummy_key')
        time.sleep(0.1)
        # Overwriting the same key must advance its modification timestamp.
        repo.set_value('dummy_key', dummy_num)
        date2 = repo.get_last_modification_time('dummy_key')
        assert date1 < date2
================================================
FILE: test/slack/__init__.py
================================================
================================================
FILE: test/slack/test_slack_api.py
================================================
import unittest
from logging import getLogger
from unittest import mock
from unittest.mock import MagicMock
from slack_sdk import WebClient
from slack_sdk.web.slack_response import SlackResponse
from testfixtures import LogCapture
import gokart.slack
logger = getLogger(__name__)
def _slack_response(token, data):
    """Build a ``SlackResponse`` stub pointing at a local dummy endpoint.

    ``data`` is the payload the fake Slack API should appear to have returned.
    """
    response_kwargs = {
        'client': WebClient(token=token),
        'http_verb': 'POST',
        'api_url': 'http://localhost:3000/api.test',
        'req_args': {},
        'data': data,
        'headers': {},
        'status_code': 200,
    }
    return SlackResponse(**response_kwargs)
class TestSlackAPI(unittest.TestCase):
    """Tests for ``gokart.slack.SlackAPI`` channel resolution and snippet sending.

    All Slack network traffic is replaced by MagicMock side effects; warnings are
    captured with ``testfixtures.LogCapture``.
    """

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_initialization_with_invalid_token(self, patch):
        # An invalid token makes the Slack API answer with ok=False.
        # (mutable-default `params={}` replaced with None; the argument is unused)
        def _conversations_list(params=None):
            return _slack_response(token='invalid', data={'ok': False, 'error': 'error_reason'})

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        patch.return_value = mock_client
        with LogCapture() as log:
            gokart.slack.SlackAPI(token='invalid', channel='test', to_user='test user')
            log.check(('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel test is not found in public channels.'))

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_invalid_channel(self, patch):
        # The token is valid, but the requested channel is not in the channel list.
        def _conversations_list(params=None):
            return _slack_response(
                token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}}
            )

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        patch.return_value = mock_client
        with LogCapture() as log:
            gokart.slack.SlackAPI(token='valid', channel='invalid_channel', to_user='test user')
            log.check(
                ('gokart.slack.slack_api', 'WARNING', 'The job will start without slack notification: Channel invalid_channel is not found in public channels.')
            )

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_send_snippet_with_invalid_token(self, patch):
        # Channel resolution succeeds, but the file upload API call fails.
        def _conversations_list(params=None):
            return _slack_response(
                token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}}
            )

        def _api_call(method, data=None):
            assert method == 'files.upload'
            return {'ok': False, 'error': 'error_reason'}

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        mock_client.api_call = MagicMock(side_effect=_api_call)
        patch.return_value = mock_client
        with LogCapture() as log:
            api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user')
            api.send_snippet(comment='test', title='title', content='content')
            log.check(
                ('gokart.slack.slack_api', 'WARNING', 'Failed to send slack notification: Error while uploading file. The error reason is "error_reason".')
            )

    @mock.patch('gokart.slack.slack_api.slack_sdk.WebClient')
    def test_send(self, patch):
        # Same failing upload as above, but without LogCapture: verifies that
        # send_snippet degrades gracefully and does not raise.
        def _conversations_list(params=None):
            return _slack_response(
                token='valid', data={'ok': True, 'channels': [{'name': 'valid', 'id': 'valid_id'}], 'response_metadata': {'next_cursor': ''}}
            )

        def _api_call(method, data=None):
            assert method == 'files.upload'
            return {'ok': False, 'error': 'error_reason'}

        mock_client = MagicMock()
        mock_client.conversations_list = MagicMock(side_effect=_conversations_list)
        mock_client.api_call = MagicMock(side_effect=_api_call)
        patch.return_value = mock_client
        api = gokart.slack.SlackAPI(token='valid', channel='valid', to_user='test user')
        api.send_snippet(comment='test', title='title', content='content')


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_build.py
================================================
from __future__ import annotations
import io
import logging
import os
import sys
import unittest
from copy import copy
from typing import Any
if sys.version_info >= (3, 11):
from typing import assert_type
else:
from typing_extensions import assert_type
from unittest.mock import patch
import luigi
import luigi.mock
import gokart
from gokart.build import GokartBuildError, LoggerConfig, TaskDumpConfig, TaskDumpMode, TaskDumpOutputType, process_task_info
from gokart.conflict_prevention_lock.task_lock import TaskLockException
class _DummyTask(gokart.TaskOnKart[str]):
    """Dumps its ``param`` string to a fixed pickle target."""

    task_namespace = __name__
    param: luigi.Parameter = luigi.Parameter()

    def output(self):
        return self.make_target('./test/dummy.pkl')

    def run(self):
        self.dump(self.param)
class _DummyTaskTwoOutputs(gokart.TaskOnKart[dict[str, str]]):
    """Task with a dict of two named output targets, one per parameter."""

    task_namespace = __name__
    param1: luigi.Parameter = luigi.Parameter()
    param2: luigi.Parameter = luigi.Parameter()

    def output(self):
        return {'out1': self.make_target('./test/dummy1.pkl'), 'out2': self.make_target('./test/dummy2.pkl')}

    def run(self):
        # Each dump is routed to the matching output key.
        self.dump(self.param1, 'out1')
        self.dump(self.param2, 'out2')
class _DummyFailedTask(gokart.TaskOnKart[Any]):
    """Task that always fails, used to exercise build() error handling."""

    task_namespace = __name__

    def run(self):
        raise RuntimeError
class _ParallelRunner(gokart.TaskOnKart[str]):
    """Fans out to 10 independent _DummyTask dependencies for multi-worker builds."""

    def requires(self):
        return [_DummyTask(param=str(i)) for i in range(10)]

    def run(self):
        self.dump('done')
class _LoadRequires(gokart.TaskOnKart[str]):
    """Loads the output of the task passed as a TaskInstanceParameter and re-dumps it."""

    task: gokart.TaskInstanceParameter[gokart.TaskOnKart[str]] = gokart.TaskInstanceParameter()

    def requires(self):
        return self.task

    def run(self):
        # self.load(self.task) reads the required task's output target.
        s = self.load(self.task)
        self.dump(s)
class RunTest(unittest.TestCase):
    """End-to-end tests for gokart.build() against luigi's mock filesystem."""

    def setUp(self):
        # Reset luigi's module-level logging/config singletons so each test
        # starts from a clean, unconfigured state.
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False
        luigi.configuration.LuigiConfigParser._instance = None
        self.config_paths = copy(luigi.configuration.LuigiConfigParser._config_paths)
        luigi.mock.MockFileSystem().clear()
        os.environ.clear()

    def tearDown(self):
        # Restore the config path list and logging flags mutated during the test.
        luigi.configuration.LuigiConfigParser._config_paths = self.config_paths
        os.environ.clear()
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    def test_build(self):
        text = 'test'
        output = gokart.build(_DummyTask(param=text), reset_register=False)
        self.assertEqual(output, text)

    def test_build_parallel(self):
        # 10 independent upstream tasks executed with 20 workers.
        output = gokart.build(_ParallelRunner(), reset_register=False, workers=20)
        self.assertEqual(output, 'done')

    def test_read_config(self):
        class _DummyTask(gokart.TaskOnKart[Any]):
            task_namespace = 'test_read_config'
            param = luigi.Parameter()

            def run(self):
                self.dump(self.param)

        # NOTE(review): presumably test/config/test_config.ini references the
        # `test_param` environment variable to fill `param` — confirm in that file.
        os.environ.setdefault('test_param', 'test')
        config_file_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'config', 'test_config.ini')
        gokart.utils.add_config(config_file_path)
        output = gokart.build(_DummyTask(), reset_register=False)
        self.assertIsInstance(output, str)
        self.assertEqual(output, 'test')

    def test_build_dict_outputs(self):
        # A task with a dict of outputs yields a dict of loaded values.
        param_dict = {
            'out1': 'test1',
            'out2': 'test2',
        }
        output = gokart.build(_DummyTaskTwoOutputs(param1=param_dict['out1'], param2=param_dict['out2']), reset_register=False)
        assert_type(output, dict[str, str])
        self.assertEqual(output, param_dict)

    def test_failed_task(self):
        with self.assertRaises(GokartBuildError):
            gokart.build(_DummyFailedTask(), reset_register=False, log_level=logging.CRITICAL)

    def test_load_requires(self):
        text = 'test'
        output = gokart.build(_LoadRequires(task=_DummyTask(param=text)), reset_register=False)
        self.assertEqual(output, text)

    def test_build_with_child_task_error(self):
        # The original exception raised by a child task must be retrievable
        # from GokartBuildError.raised_exceptions, keyed by the task's unique id.
        class CheckException(Exception):
            pass

        class FailTask(gokart.TaskOnKart[Any]):
            def run(self):
                raise CheckException()

        t = FailTask()
        with self.assertRaises(GokartBuildError) as cm:
            gokart.build(t, reset_register=False, log_level=logging.CRITICAL)
        e = cm.exception
        self.assertEqual(len(e.raised_exceptions), 1)
        self.assertIsInstance(e.raised_exceptions[t.make_unique_id()][0], CheckException)
class LoggerConfigTest(unittest.TestCase):
    """LoggerConfig should enable exactly the requested level and above, nothing below."""

    def test_logger_config(self):
        # (requested level, a level that must be enabled, a level that must not be)
        cases = [
            (logging.INFO, logging.INFO, logging.DEBUG),
            (logging.DEBUG, logging.DEBUG, logging.NOTSET),
            (logging.CRITICAL, logging.CRITICAL, logging.ERROR),
        ]
        for level, enabled_level, disabled_level in cases:
            with self.subTest(level=level, enable_expected=enabled_level, disable_expected=disabled_level):
                with LoggerConfig(level) as configured:
                    self.assertTrue(configured.logger.isEnabledFor(enabled_level))
                    self.assertFalse(configured.logger.isEnabledFor(disabled_level))
class ProcessTaskInfoTest(unittest.TestCase):
    """process_task_info() must log output containing the task's unique id, in both TREE and TABLE modes."""

    def test_process_task_info(self):
        task = _DummyTask(param='test')
        for config in (
            TaskDumpConfig(mode=TaskDumpMode.TREE, output_type=TaskDumpOutputType.PRINT),
            TaskDumpConfig(mode=TaskDumpMode.TABLE, output_type=TaskDumpOutputType.PRINT),
        ):
            with LoggerConfig(level=logging.INFO):
                from gokart.build import logger

                # Capture the module logger's output into an in-memory stream.
                log_stream = io.StringIO()
                handler = logging.StreamHandler(log_stream)
                handler.setLevel(logging.INFO)
                logger.addHandler(handler)
                try:
                    process_task_info(task, config)
                finally:
                    # Always detach the handler so a failing call cannot leak
                    # it into subsequent tests (the original skipped cleanup on error).
                    logger.removeHandler(handler)
                    handler.close()
                self.assertIn(member=str(task.make_unique_id()), container=log_stream.getvalue())
class _FailThreeTimesAndSuccessTask(gokart.TaskOnKart[Any]):
    """Raises TaskLockException on the first three run() calls, then succeeds."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Attempt counter kept on the instance so it persists across retries.
        self.failed_counter = 0

    def run(self):
        if self.failed_counter < 3:
            self.failed_counter += 1
            raise TaskLockException()
        self.dump('done')
class TestBuildHasLockedTaskException(unittest.TestCase):
    """build() should keep retrying (with backoff) when tasks fail due to lock collisions."""

    def test_build_expo_backoff_when_luigi_failed_due_to_locked_task(self):
        # The task raises TaskLockException three times before succeeding;
        # build() must retry until success instead of raising.
        gokart.build(_FailThreeTimesAndSuccessTask(), reset_register=False)
class TestBuildFailedAndSchedulingFailed(unittest.TestCase):
    """gokart.build() must translate luigi failure status codes into GokartBuildError."""

    @staticmethod
    def _mock_result(status, summary_text):
        """Build a stub exposing the two attributes gokart.build() reads from luigi's result."""

        class _Result:
            pass

        result = _Result()
        result.status = status
        result.summary_text = summary_text
        return result

    def test_build_raises_exception_on_failed_and_scheduling_failed(self):
        """build() raises GokartBuildError when luigi reports FAILED_AND_SCHEDULING_FAILED."""
        with patch('luigi.build') as mock_luigi_build:
            mock_luigi_build.return_value = self._mock_result(luigi.LuigiStatusCode.FAILED_AND_SCHEDULING_FAILED, 'Task failed and scheduling failed')
            with self.assertRaises(GokartBuildError):
                gokart.build(_DummyTask(param='test'), reset_register=False, log_level=logging.CRITICAL)

    def test_build_not_raises_exception_when_success_with_retry(self):
        """build() does not raise GokartBuildError when the status is SUCCESS_WITH_RETRY."""
        with patch('luigi.build') as mock_luigi_build:
            mock_luigi_build.return_value = self._mock_result(luigi.LuigiStatusCode.SUCCESS_WITH_RETRY, 'Task completed successfully after retries')
            # return_value=False: the mocked luigi.build never produced real output to load.
            gokart.build(_DummyTask(param='test'), reset_register=False, return_value=False, log_level=logging.CRITICAL)

    def test_build_raises_exception_on_scheduling_failed_only(self):
        """build() raises GokartBuildError when luigi reports SCHEDULING_FAILED.

        Renamed from ``test_build_not_raises_exception_on_scheduling_failed_only``:
        the old name contradicted the asserted behaviour.
        """
        with patch('luigi.build') as mock_luigi_build:
            mock_luigi_build.return_value = self._mock_result(luigi.LuigiStatusCode.SCHEDULING_FAILED, 'Task scheduling failed')
            with self.assertRaises(GokartBuildError):
                gokart.build(_DummyTask(param='test'), reset_register=False, log_level=logging.CRITICAL)


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_cache_unique_id.py
================================================
import os
import unittest
from typing import Any
import luigi
import luigi.mock
import gokart
class _DummyTask(gokart.TaskOnKart[Any]):
    """Passes its dependency's output straight through."""

    def requires(self):
        return _DummyTaskDep()

    def run(self):
        self.dump(self.load())
class _DummyTaskDep(gokart.TaskOnKart[str]):
    """Dumps its ``param``; the tests swap this parameter's default between builds."""

    param: luigi.Parameter = luigi.Parameter()

    def run(self):
        self.dump(self.param)
class CacheUniqueIDTest(unittest.TestCase):
    """cache_unique_id controls whether a task's unique id survives dependency changes."""

    def setUp(self):
        # Reset luigi config and the mock filesystem so builds are isolated.
        luigi.configuration.LuigiConfigParser._instance = None
        luigi.mock.MockFileSystem().clear()
        os.environ.clear()

    @staticmethod
    def _set_param(cls, attr_name: str, param: luigi.Parameter) -> None:  # type: ignore
        # Luigi 3.8.0+ uses __set_name__ to register _attribute_name on Parameter descriptors.
        # When assigning after class creation (bypassing the metaclass), call it manually.
        param.__set_name__(cls, attr_name)
        setattr(cls, attr_name, param)

    def test_cache_unique_id_true(self):
        # With caching on, changing the dependency's default between builds
        # does not change the result: the first build's output is reused.
        self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='original_param'))
        output1 = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False)
        self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='updated_param'))
        output2 = gokart.build(_DummyTask(cache_unique_id=True), reset_register=False)
        self.assertEqual(output1, output2)

    def test_cache_unique_id_false(self):
        # With caching off, the second build sees the updated dependency and recomputes.
        self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='original_param'))
        output1 = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False)
        self._set_param(_DummyTaskDep, 'param', luigi.Parameter(default='updated_param'))
        output2 = gokart.build(_DummyTask(cache_unique_id=False), reset_register=False)
        self.assertNotEqual(output1, output2)


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_config_params.py
================================================
import unittest
from typing import Any
import luigi
from luigi.cmdline_parser import CmdlineParser
import gokart
from gokart.config_params import inherits_config_params
def in_parse(cmds, deferred_computation):
    """Parse ``cmds`` as a luigi command line and pass the resulting task object to ``deferred_computation``.

    Function copied from luigi: https://github.com/spotify/luigi/blob/e2228418eec60b68ca09a30c878ab26413846847/test/helpers.py
    """
    with CmdlineParser.global_instance(cmds) as cp:
        deferred_computation(cp.get_task_obj())
class ConfigClass(luigi.Config):
    """Config section supplying default values consumed via @inherits_config_params."""

    param_a = luigi.Parameter(default='config a')
    param_b = luigi.Parameter(default='config b')
    param_c = luigi.Parameter(default='config c')
@inherits_config_params(ConfigClass)
class Inherited(gokart.TaskOnKart[Any]):
    """Task inheriting param_a/param_b values from ConfigClass."""

    param_a = luigi.Parameter()
    # The class-level default is expected to be replaced by ConfigClass's value.
    param_b = luigi.Parameter(default='overrided')
@inherits_config_params(ConfigClass, parameter_alias={'param_a': 'param_d'})
class Inherited2(gokart.TaskOnKart[Any]):
    """Task inheriting param_c directly and receiving ConfigClass.param_a under the alias param_d."""

    param_c = luigi.Parameter()
    param_d = luigi.Parameter()
class ChildTask(Inherited):
    """Plain subclass; config-param inheritance must carry over."""

    pass


class ChildTaskWithNewParam(Inherited):
    """Subclass adding a parameter that ConfigClass does not define."""

    param_new = luigi.Parameter()


class ConfigClass2(luigi.Config):
    """Second config section, used to re-bind param_a on a subclass."""

    param_a = luigi.Parameter(default='config a from config class 2')


@inherits_config_params(ConfigClass2)
class ChildTaskWithNewConfig(Inherited):
    """Subclass whose own decorator takes param_a from ConfigClass2 instead."""

    pass
class TestInheritsConfigParam(unittest.TestCase):
    """Behaviour of the @inherits_config_params decorator."""

    def test_inherited_params(self):
        # Values are filled in from ConfigClass.
        in_parse(['Inherited'], lambda task: self.assertEqual(task.param_a, 'config a'))
        # The config value replaces the class-level default ('overrided').
        in_parse(['Inherited'], lambda task: self.assertEqual(task.param_b, 'config b'))
        # Command line argument takes precedence over config param
        in_parse(['Inherited', '--param-a', 'command line arg'], lambda task: self.assertEqual(task.param_a, 'command line arg'))
        # Parameters which are not members of the task will not be set
        with self.assertRaises(AttributeError):
            in_parse(['Inherited'], lambda task: task.param_c)
        # Parameter name alias: param_d receives ConfigClass.param_a.
        in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_c, 'config c'))
        in_parse(['Inherited2'], lambda task: self.assertEqual(task.param_d, 'config a'))

    def test_child_task(self):
        # Subclasses keep the inherited config parameters and CLI precedence.
        in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_a, 'config a'))
        in_parse(['ChildTask'], lambda task: self.assertEqual(task.param_b, 'config b'))
        in_parse(['ChildTask', '--param-a', 'command line arg'], lambda task: self.assertEqual(task.param_a, 'command line arg'))
        with self.assertRaises(AttributeError):
            in_parse(['ChildTask'], lambda task: task.param_c)

    def test_child_override(self):
        # The subclass decorator re-binds param_a to ConfigClass2 ...
        in_parse(['ChildTaskWithNewConfig'], lambda task: self.assertEqual(task.param_a, 'config a from config class 2'))
        # ... while param_b still comes from the original ConfigClass.
        in_parse(['ChildTaskWithNewConfig'], lambda task: self.assertEqual(task.param_b, 'config b'))
================================================
FILE: test/test_explicit_bool_parameter.py
================================================
import unittest
from typing import Any
import luigi
import luigi.mock
from luigi.cmdline_parser import CmdlineParser
import gokart
def in_parse(cmds, deferred_computation):
    """Parse ``cmds`` as a luigi command line and pass the resulting task object to ``deferred_computation``."""
    with CmdlineParser.global_instance(cmds) as cp:
        deferred_computation(cp.get_task_obj())
class WithDefaultTrue(gokart.TaskOnKart[Any]):
    """Task whose ExplicitBoolParameter defaults to True."""

    param = gokart.ExplicitBoolParameter(default=True)


class WithDefaultFalse(gokart.TaskOnKart[Any]):
    """Task whose ExplicitBoolParameter defaults to False."""

    param = gokart.ExplicitBoolParameter(default=False)


class ExplicitParsing(gokart.TaskOnKart[Any]):
    """Task with a required ExplicitBoolParameter; run() records the parsed value on the class."""

    param = gokart.ExplicitBoolParameter()

    def run(self):
        ExplicitParsing._param = self.param  # type: ignore
class TestExplicitBoolParameter(unittest.TestCase):
    """CLI parsing behaviour of gokart.ExplicitBoolParameter."""

    def test_bool_default(self):
        self.assertTrue(WithDefaultTrue().param)
        self.assertFalse(WithDefaultFalse().param)

    def test_parse_param(self):
        # Both lower-case and capitalised spellings are accepted.
        in_parse(['ExplicitParsing', '--param', 'true'], lambda task: self.assertTrue(task.param))
        in_parse(['ExplicitParsing', '--param', 'false'], lambda task: self.assertFalse(task.param))
        in_parse(['ExplicitParsing', '--param', 'True'], lambda task: self.assertTrue(task.param))
        in_parse(['ExplicitParsing', '--param', 'False'], lambda task: self.assertFalse(task.param))

    def test_missing_parameter(self):
        # The parameter has no default, so omitting it entirely must fail.
        with self.assertRaises(luigi.parameter.MissingParameterException):
            in_parse(['ExplicitParsing'], lambda: True)

    def test_value_error(self):
        # Anything other than a boolean spelling is rejected.
        with self.assertRaises(ValueError):
            in_parse(['ExplicitParsing', '--param', 'Foo'], lambda: True)

    def test_expected_one_argument_error(self):
        # Renamed from test_expected_one_argment_error (typo in identifier).
        # argparse throws an "expected one argument" error, exiting the process.
        with self.assertRaises(SystemExit):
            in_parse(['ExplicitParsing', '--param'], lambda: True)
================================================
FILE: test/test_gcs_config.py
================================================
import os
import unittest
from unittest.mock import MagicMock, patch
from gokart.gcs_config import GCSConfig
class TestGCSConfig(unittest.TestCase):
    """Credential resolution of GCSConfig._get_gcs_client from an environment variable."""

    def test_get_gcs_client_without_gcs_credential_name(self):
        # An empty env value means "no credentials": the client is built with None.
        client_mock = MagicMock()
        os.environ['env_name'] = ''
        with patch('luigi.contrib.gcs.GCSClient', client_mock):
            GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
        self.assertEqual(dict(oauth_credentials=None), client_mock.call_args[1])

    def test_get_gcs_client_with_file_path(self):
        # When the env value is an existing file path, credentials are loaded from that file.
        credentials_mock = MagicMock()
        file_path = 'test.json'
        os.environ['env_name'] = file_path
        with patch('luigi.contrib.gcs.GCSClient'), \
                patch('google.oauth2.service_account.Credentials.from_service_account_file', credentials_mock), \
                patch('os.path.isfile', return_value=True):
            GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
        self.assertEqual(file_path, credentials_mock.call_args[0][0])

    def test_get_gcs_client_with_json(self):
        # When the env value is inline JSON, credentials are parsed from the decoded dict.
        credentials_mock = MagicMock()
        json_str = '{"test": 1}'
        os.environ['env_name'] = json_str
        with patch('luigi.contrib.gcs.GCSClient'), \
                patch('google.oauth2.service_account.Credentials.from_service_account_info', credentials_mock):
            GCSConfig(gcs_credential_name='env_name')._get_gcs_client()
        self.assertEqual(dict(test=1), credentials_mock.call_args[0][0])
================================================
FILE: test/test_gcs_obj_metadata_client.py
================================================
from __future__ import annotations
import datetime
import unittest
from typing import Any
from unittest.mock import MagicMock, patch
import gokart
from gokart.gcs_obj_metadata_client import GCSObjectMetadataClient
from gokart.required_task_output import RequiredTaskOutput
from gokart.target import TargetOnKart
class _DummyTaskOnKart(gokart.TaskOnKart[str]):
    """Minimal task used to exercise dump() metadata forwarding."""

    task_namespace = __name__

    def run(self):
        self.dump('Dummy TaskOnKart')
class TestGCSObjectMetadataClient(unittest.TestCase):
    """Unit tests for GCSObjectMetadataClient: label normalization, merging, conflicts and size limits."""

    def setUp(self):
        # Representative serialized task parameters; param6 is deliberately empty.
        self.task_params: dict[str, str] = {
            'param1': 'a' * 1000,
            'param2': str(1000),
            'param3': str({'key1': 'value1', 'key2': True, 'key3': 2}),
            'param4': str([1, 2, 3, 4, 5]),
            'param5': str(datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5)),
            'param6': '',
        }
        # User-supplied labels of assorted types.
        self.custom_labels: dict[str, Any] = {
            'created_at': datetime.datetime(year=2025, month=1, day=2, hour=3, minute=4, second=5),
            'created_by': 'hoge fuga',
            'empty': True,
            'try_num': 3,
        }
        # Keys overlapping with custom_labels, to exercise conflict resolution.
        self.task_params_with_conflicts = {
            'empty': 'False',
            'created_by': 'fuga hoge',
            'param1': 'a' * 10,
        }

    def test_normalize_labels_not_empty(self):
        # None normalizes to an empty dict rather than raising.
        got = GCSObjectMetadataClient._normalize_labels(None)
        self.assertEqual(got, {})

    def test_normalize_labels_has_value(self):
        got = GCSObjectMetadataClient._normalize_labels(self.task_params)
        self.assertIsInstance(got, dict)  # (original had this assertion duplicated)
        for key in ('param1', 'param2', 'param3', 'param4', 'param5', 'param6'):
            self.assertIn(key, got)

    def test_get_patched_obj_metadata_only_task_params(self):
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=None)
        self.assertIsInstance(got, dict)
        for key in ('param1', 'param2', 'param3', 'param4', 'param5'):
            self.assertIn(key, got)
        # Empty parameter values are dropped from the metadata.
        self.assertNotIn('param6', got)

    def test_get_patched_obj_metadata_only_custom_labels(self):
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=None, custom_labels=self.custom_labels)
        self.assertIsInstance(got, dict)
        for key in ('created_at', 'created_by', 'empty', 'try_num'):
            self.assertIn(key, got)

    def test_get_patched_obj_metadata_with_both_task_params_and_custom_labels(self):
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params, custom_labels=self.custom_labels)
        self.assertIsInstance(got, dict)
        for key in ('param1', 'param2', 'param3', 'param4', 'param5'):
            self.assertIn(key, got)
        self.assertNotIn('param6', got)
        for key in ('created_at', 'created_by', 'empty', 'try_num'):
            self.assertIn(key, got)

    def test_get_patched_obj_metadata_with_exceeded_size_metadata(self):
        # Labels are added until the GCS metadata size limit is hit; param2 no longer fits.
        size_exceeded_task_params = {
            'param1': 'a' * 5000,
            'param2': 'b' * 5000,
        }
        want = {
            'param1': 'a' * 5000,
        }
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=size_exceeded_task_params)
        self.assertEqual(got, want)

    def test_get_patched_obj_metadata_with_conflicts(self):
        got = GCSObjectMetadataClient._get_patched_obj_metadata({}, task_params=self.task_params_with_conflicts, custom_labels=self.custom_labels)
        self.assertIsInstance(got, dict)
        for key in ('created_at', 'created_by', 'empty', 'try_num'):
            self.assertIn(key, got)
        # On key conflicts the custom label wins over the task parameter.
        self.assertEqual(got['empty'], 'True')
        self.assertEqual(got['created_by'], 'hoge fuga')
        self.assertEqual(got['param1'], 'a' * 10)

    def test_get_patched_obj_metadata_with_required_task_outputs(self):
        got = GCSObjectMetadataClient._get_patched_obj_metadata(
            {},
            required_task_outputs=[
                RequiredTaskOutput(task_name='task1', output_path='path/to/output1'),
            ],
        )
        self.assertIsInstance(got, dict)
        self.assertIn('__required_task_outputs', got)
        self.assertEqual(got['__required_task_outputs'], '[{"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}]')

    def test_get_patched_obj_metadata_with_nested_required_task_outputs(self):
        # Nested dict structures of required outputs are serialized recursively.
        got = GCSObjectMetadataClient._get_patched_obj_metadata(
            {},
            required_task_outputs={
                'nested_task': {'nest': RequiredTaskOutput(task_name='task1', output_path='path/to/output1')},
            },
        )
        self.assertIsInstance(got, dict)
        self.assertIn('__required_task_outputs', got)
        self.assertEqual(
            got['__required_task_outputs'], '{"nested_task": {"nest": {"__gokart_task_name": "task1", "__gokart_output_path": "path/to/output1"}}}'
        )

    def test_adjust_gcs_metadata_limit_size_runtime_error(self):
        # Smoke test: trimming oversized labels must not raise, even for ~100KB of labels.
        large_labels = {}
        for i in range(100):
            large_labels[f'key_{i}'] = 'x' * 1000
        GCSObjectMetadataClient._adjust_gcs_metadata_limit_size(large_labels)
class TestGokartTask(unittest.TestCase):
    """TaskOnKart.dump() must forward metadata keyword arguments to the target's dump()."""

    @patch.object(_DummyTaskOnKart, '_get_output_target')
    def test_mock_target_on_kart(self, mock_get_output_target):
        mock_target = MagicMock(spec=TargetOnKart)
        mock_get_output_target.return_value = mock_target
        task = _DummyTaskOnKart()
        task.dump({'key': 'value'}, mock_target)
        # Lock flag, serialized task params, custom labels and required outputs
        # are all passed through to the target.
        mock_target.dump.assert_called_once_with(
            {'key': 'value'}, lock_at_dump=task._lock_at_dump, task_params={}, custom_labels=None, required_task_outputs=[]
        )


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_info.py
================================================
import unittest
from unittest.mock import patch
import luigi
import luigi.mock
from luigi.mock import MockFileSystem, MockTarget
import gokart
import gokart.info
from test.tree.test_task_info import _DoubleLoadSubTask, _SubTask, _Task
class TestInfo(unittest.TestCase):
    """gokart.info.make_tree_info() rendering: status, abbreviation and task filtering.

    Expected trees are matched with assertRegex; luigi targets are redirected to
    MockTarget so no real files are written.
    """

    def setUp(self) -> None:
        MockFileSystem().clear()
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    def tearDown(self) -> None:
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_pending(self):
        task = _Task(param=1, sub=_SubTask(param=2))
        # check before running
        tree = gokart.info.make_tree_info(task)
        expected = r"""
└─-\(PENDING\) _Task\[[a-z0-9]*\]
└─-\(PENDING\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_complete(self):
        task = _Task(param=1, sub=_SubTask(param=2))
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task)
        expected = r"""
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_abbreviation(self):
        # Two identical sub-trees: the repeated one is abbreviated to "...".
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task)
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]
\|--\(COMPLETE\) _Task\[[a-z0-9]*\]
\| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─- \.\.\.$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_not_compress(self):
        # abbr=False expands the repeated sub-tree in full.
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task, abbr=False)
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]
\|--\(COMPLETE\) _Task\[[a-z0-9]*\]
\| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_not_compress_ignore_task(self):
        # ignore_task_names prunes matching tasks (and their sub-trees) from the output.
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = gokart.info.make_tree_info(task, abbr=False, ignore_task_names=['_Task'])
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_large_data_fram_processor.py
================================================
import os
import shutil
import unittest
import numpy as np
import pandas as pd
from gokart.target import LargeDataFrameProcessor
from test.util import _get_temporary_directory
class LargeDataFrameProcessorTest(unittest.TestCase):
    """Round-trip tests for LargeDataFrameProcessor's chunked zip storage."""

    def setUp(self):
        self.temporary_directory = _get_temporary_directory()

    def tearDown(self):
        shutil.rmtree(self.temporary_directory, ignore_errors=True)

    def _round_trip(self, frame: pd.DataFrame, file_name: str) -> None:
        """Save *frame* under the temp dir, reload it, and assert equality."""
        file_path = os.path.join(self.temporary_directory, file_name)
        processor = LargeDataFrameProcessor(max_byte=int(1e6))
        processor.save(frame, file_path)
        reloaded = processor.load(file_path)
        pd.testing.assert_frame_equal(reloaded, frame, check_like=True)

    def test_save_and_load(self):
        # Large enough to be split into several chunks (max_byte = 1e6).
        frame = pd.DataFrame(dict(data=np.random.uniform(0, 1, size=int(1e6))))
        self._round_trip(frame, 'test.zip')

    def test_save_and_load_empty(self):
        # An empty frame must survive the round trip as well.
        self._round_trip(pd.DataFrame(), 'test_with_empty.zip')


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_list_task_instance_parameter.py
================================================
import unittest
from typing import Any
import luigi
import gokart
from gokart import TaskOnKart
class _DummySubTask(TaskOnKart[Any]):
    """Parameter-less task used as a default TaskInstanceParameter value."""

    task_namespace = __name__
    pass


class _DummyTask(TaskOnKart[Any]):
    """Task with an int parameter and a task-instance parameter, for serialization tests."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter()
    task: gokart.TaskInstanceParameter[_DummySubTask] = gokart.TaskInstanceParameter(default=_DummySubTask())
class ListTaskInstanceParameterTest(unittest.TestCase):
    """Serialize/parse round trip for gokart.ListTaskInstanceParameter."""

    def setUp(self):
        # Clear the instance cache so identical params yield fresh instances per test.
        _DummyTask.clear_instance_cache()

    def test_serialize_and_parse(self):
        tasks = [_DummyTask(param=3), _DummyTask(param=3)]
        serialized = gokart.ListTaskInstanceParameter().serialize(tasks)
        restored = gokart.ListTaskInstanceParameter().parse(serialized)
        # Round-tripping must preserve each task's identity (task_id).
        self.assertEqual(restored[0].task_id, tasks[0].task_id)
        self.assertEqual(restored[1].task_id, tasks[1].task_id)


if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_mypy.py
================================================
import tempfile
import unittest
from mypy import api
from test.config import PYPROJECT_TOML
class TestMyMypyPlugin(unittest.TestCase):
    """Run mypy (configured via PYPROJECT_TOML, which enables the gokart plugin) on generated snippets.

    The snippets are written to a temporary .py file and checked with
    ``mypy.api.run``; the assertions inspect mypy's stdout and exit code.
    """

    def test_plugin_no_issue(self):
        """A correctly-typed task construction must produce no mypy errors."""
        # The snippet is kept at column 0 because mypy checks it as a standalone file.
        test_code = """
import luigi
from luigi import Parameter
import gokart
import datetime

class MyTask(gokart.TaskOnKart):
    foo: int = luigi.IntParameter() # type: ignore
    bar: str = luigi.Parameter() # type: ignore
    baz: bool = gokart.ExplicitBoolParameter()
    qux: str = Parameter()
    # https://github.com/m3dev/gokart/issues/395
    datetime: datetime.datetime = luigi.DateMinuteParameter(interval=10, default=datetime.datetime(2021, 1, 1))

# TaskOnKart parameters:
# - `complete_check_at_run`
MyTask(foo=1, bar='bar', baz=False, qux='qux', complete_check_at_run=False)
"""
        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
            test_file.write(test_code.encode('utf-8'))
            test_file.flush()
            # --no-incremental/--cache-dir=/dev/null keep runs independent of mypy's cache.
            stdout, stderr, exitcode = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
            self.assertEqual(exitcode, 0, f'mypy plugin error occurred:\nstdout: {stdout}\nstderr: {stderr}')
            self.assertIn('Success: no issues found', stdout)

    def test_plugin_invalid_arg(self):
        """Mismatched argument types (including TaskOnKart's own parameters) must be flagged."""
        test_code = """
import luigi
import gokart

class MyTask(gokart.TaskOnKart):
    foo: int = luigi.IntParameter() # type: ignore
    bar: str = luigi.Parameter() # type: ignore
    baz: bool = gokart.ExplicitBoolParameter()

# issue: foo is int
# not issue: bar is missing, because it can be set by config file.
# TaskOnKart parameters:
# - `complete_check_at_run`
MyTask(foo='1', baz='not bool', complete_check_at_run='not bool')
"""
        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
            test_file.write(test_code.encode('utf-8'))
            test_file.flush()
            stdout, stderr, exitcode = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
            self.assertEqual(exitcode, 1, f'mypy plugin error not occurred:\nstdout: {stdout}\nstderr: {stderr}')
            self.assertIn('error: Argument "foo" to "MyTask" has incompatible type "str"; expected "int" [arg-type]', stdout)
            self.assertIn('error: Argument "baz" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', stdout)
            self.assertIn('error: Argument "complete_check_at_run" to "MyTask" has incompatible type "str"; expected "bool" [arg-type]', stdout)
            self.assertIn('Found 3 errors in 1 file (checked 1 source file)', stdout)
================================================
FILE: test/test_pandas_type_check_framework.py
================================================
from __future__ import annotations
import logging
import unittest
from logging import getLogger
from typing import Any
from unittest.mock import patch
import luigi
import pandas as pd
from luigi.mock import MockFileSystem, MockTarget
import gokart
from gokart.build import GokartBuildError
from gokart.pandas_type_config import PandasTypeConfig
logger = getLogger(__name__)
class TestPandasTypeConfig(PandasTypeConfig):
    """Type config requiring the 'system_cd' column to contain ints, scoped to this test namespace."""

    task_namespace = 'test_pandas_type_check_framework'

    @classmethod
    def type_dict(cls) -> dict[str, Any]:
        # Column-name -> required-type mapping enforced by the framework.
        return {'system_cd': int}
class _DummyFailTask(gokart.TaskOnKart[pd.DataFrame]):
    """Dumps a frame whose 'system_cd' holds strings — violates TestPandasTypeConfig's int rule."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun: luigi.BoolParameter = luigi.BoolParameter(default=True, significant=False)

    def output(self):
        return self.make_target('dummy.pkl')

    def run(self):
        # '1' is a str, so the type check on dump should fail.
        df = pd.DataFrame(dict(system_cd=['1']))
        self.dump(df)
class _DummyFailWithNoneTask(gokart.TaskOnKart[pd.DataFrame]):
    """Dumps a frame where 'system_cd' contains a None — should also fail the int type check."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun: luigi.BoolParameter = luigi.BoolParameter(default=True, significant=False)

    def output(self):
        return self.make_target('dummy.pkl')

    def run(self):
        # The None entry makes the column violate the required int type.
        df = pd.DataFrame(dict(system_cd=[1, None]))
        self.dump(df)
class _DummySuccessTask(gokart.TaskOnKart[pd.DataFrame]):
    """Dumps a frame whose 'system_cd' holds ints — satisfies TestPandasTypeConfig."""

    task_namespace = 'test_pandas_type_check_framework'
    rerun: luigi.BoolParameter = luigi.BoolParameter(default=True, significant=False)

    def output(self):
        return self.make_target('dummy.pkl')

    def run(self):
        df = pd.DataFrame(dict(system_cd=[1]))
        self.dump(df)
class TestPandasTypeCheckFramework(unittest.TestCase):
    """End-to-end checks that PandasTypeConfig violations fail a gokart build/run."""

    def setUp(self) -> None:
        # Reset luigi's global logging state and the in-memory filesystem
        # so earlier tests cannot leak configuration into this one.
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False
        MockFileSystem().clear()
        # same way as luigi https://github.com/spotify/luigi/blob/fe7ecf4acf7cf4c084bd0f32162c8e0721567630/test/helpers.py#L175
        self._stashed_reg = luigi.task_register.Register._get_reg()

    def tearDown(self) -> None:
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False
        # Restore the task register stashed in setUp.
        luigi.task_register.Register._set_reg(self._stashed_reg)

    @patch('sys.argv', new=['main', 'test_pandas_type_check_framework._DummyFailTask', '--log-level=CRITICAL', '--local-scheduler', '--no-lock'])
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_fail_with_gokart_run(self):
        """gokart.run() must exit non-zero when the dumped frame violates the type config."""
        with self.assertRaises(SystemExit) as exit_code:
            gokart.run()
        self.assertNotEqual(exit_code.exception.code, 0)  # raise Error

    def test_fail(self):
        with self.assertRaises(GokartBuildError):
            gokart.build(_DummyFailTask(), log_level=logging.CRITICAL)

    def test_fail_with_None(self):
        with self.assertRaises(GokartBuildError):
            gokart.build(_DummyFailWithNoneTask(), log_level=logging.CRITICAL)

    def test_success(self):
        gokart.build(_DummySuccessTask())
        # no error
================================================
FILE: test/test_pandas_type_config.py
================================================
from __future__ import annotations
from datetime import date, datetime
from typing import Any
from unittest import TestCase
import numpy as np
import pandas as pd
from gokart import PandasTypeConfig
from gokart.pandas_type_config import PandasTypeError
class _DummyPandasTypeConfig(PandasTypeConfig):
    """Config declaring int, datetime and ndarray column requirements for the tests below."""

    @classmethod
    def type_dict(cls) -> dict[str, Any]:
        return {'int_column': int, 'datetime_column': datetime, 'array_column': np.ndarray}
class TestPandasTypeConfig(TestCase):
    """check() accepts columns matching _DummyPandasTypeConfig and rejects mismatches."""

    def _assert_check_rejects(self, frame: pd.DataFrame) -> None:
        # Shared helper: the given frame must fail the type check.
        with self.assertRaises(PandasTypeError):
            _DummyPandasTypeConfig().check(frame)

    def test_int_fail(self):
        self._assert_check_rejects(pd.DataFrame(dict(int_column=['1'])))

    def test_int_success(self):
        _DummyPandasTypeConfig().check(pd.DataFrame(dict(int_column=[1])))

    def test_datetime_fail(self):
        # A date is not a datetime, so this must be rejected.
        self._assert_check_rejects(pd.DataFrame(dict(datetime_column=[date(2019, 1, 12)])))

    def test_datetime_success(self):
        _DummyPandasTypeConfig().check(pd.DataFrame(dict(datetime_column=[datetime(2019, 1, 12, 0, 0, 0)])))

    def test_array_fail(self):
        # A plain list is not an np.ndarray, so this must be rejected.
        self._assert_check_rejects(pd.DataFrame(dict(array_column=[[1, 2]])))

    def test_array_success(self):
        _DummyPandasTypeConfig().check(pd.DataFrame(dict(array_column=[np.array([1, 2])])))
================================================
FILE: test/test_restore_task_by_id.py
================================================
import unittest
from typing import Any
from unittest.mock import patch
import luigi
import luigi.mock
import gokart
class _SubDummyTask(gokart.TaskOnKart[str]):
    """Parameterized leaf task that dumps a constant string."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter()

    def run(self):
        self.dump('test')
class _DummyTask(gokart.TaskOnKart[str]):
    """Task carrying a sub-task via TaskInstanceParameter; used for restore-by-id tests."""

    task_namespace = __name__
    sub_task: gokart.TaskInstanceParameter[gokart.TaskOnKart[Any]] = gokart.TaskInstanceParameter()

    def output(self):
        return self.make_target('test.txt')

    def run(self):
        self.dump('test')
class RestoreTaskByIDTest(unittest.TestCase):
    """A task built once should be restorable from its unique id with the same identity."""

    def setUp(self) -> None:
        # Clear the in-memory filesystem shared across luigi mock targets.
        luigi.mock.MockFileSystem().clear()

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
    def test(self):
        task = _DummyTask(sub_task=_SubDummyTask(param=10))
        luigi.build([task], local_scheduler=True, log_level='CRITICAL')
        unique_id = task.make_unique_id()
        restored = _DummyTask.restore(unique_id)
        # BUG FIX: assertTrue(a, b) only checks the truthiness of `a` and treats `b`
        # as the failure message, so the original assertion could never detect a
        # mismatch. The intent is an equality comparison between the two ids.
        self.assertEqual(task.make_unique_id(), restored.make_unique_id())
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()
================================================
FILE: test/test_run.py
================================================
import os
import unittest
from typing import Any
from unittest.mock import MagicMock, patch
import luigi
import luigi.mock
import gokart
from gokart.run import _try_to_send_event_summary_to_slack
class _DummyTask(gokart.TaskOnKart[Any]):
    """Task with a single required string parameter, driven via argv in the run() tests."""

    task_namespace = __name__
    # NOTE(review): `luigi.StrParameter` is unusual — plain `luigi.Parameter` is the
    # common string parameter; confirm this attribute exists in the pinned luigi version.
    param: luigi.StrParameter = luigi.StrParameter()
class RunTest(unittest.TestCase):
    """Tests for the gokart.run() CLI entry point and the Slack event-summary helper."""

    def setUp(self):
        # Reset luigi's config singleton, the mock filesystem and the environment
        # so each test starts from a clean slate.
        luigi.configuration.LuigiConfigParser._instance = None
        luigi.mock.MockFileSystem().clear()
        os.environ.clear()

    @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--param', 'test', '--log-level=CRITICAL', '--local-scheduler'])
    def test_run(self):
        """A fully-parameterized run should exit with status code 0."""
        # NOTE(review): os.path.dirname(__name__) takes the dirname of the module
        # *name*, not a file path (likely meant __file__). Left unchanged because
        # the test currently relies on the resulting path resolution.
        config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
        luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
        os.environ.setdefault('test_param', 'test')
        with self.assertRaises(SystemExit) as exit_code:
            gokart.run()
        self.assertEqual(exit_code.exception.code, 0)

    @patch('sys.argv', new=['main', f'{__name__}._DummyTask', '--log-level=CRITICAL', '--local-scheduler'])
    def test_run_with_undefined_environ(self):
        """Without the required parameter supplied, run() must raise MissingParameterException."""
        config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
        luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
        with self.assertRaises(luigi.parameter.MissingParameterException):
            gokart.run()

    @patch(
        'sys.argv',
        new=[
            'main',
            '--tree-info-mode=simple',
            '--tree-info-output-path=tree.txt',
            f'{__name__}._DummyTask',
            '--param',
            'test',
            '--log-level=CRITICAL',
            '--local-scheduler',
        ],
    )
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: luigi.mock.MockTarget(path, **kwargs))
    def test_run_tree_info(self):
        """Running with tree-info flags should dump the dependency tree to the target."""
        config_file_path = os.path.join(os.path.dirname(__name__), 'config', 'test_config.ini')
        luigi.configuration.LuigiConfigParser.add_config_path(config_file_path)
        os.environ.setdefault('test_param', 'test')
        tree_info = gokart.tree_info(mode='simple', output_path='tree.txt')
        with self.assertRaises(SystemExit):
            gokart.run()
        # BUG FIX: assertTrue(a, b) only checks the truthiness of `a` and uses `b`
        # as the failure message, so the old assertion could never fail on a
        # mismatch. The intent is clearly an equality comparison.
        self.assertEqual(gokart.make_tree_info(_DummyTask(param='test')), tree_info.output().load())

    @patch('gokart.make_tree_info')
    def test_try_to_send_event_summary_to_slack(self, make_tree_info_mock: MagicMock) -> None:
        """The Slack snippet includes tree info only when SlackConfig.send_tree_info is set."""
        event_aggregator_mock = MagicMock()
        event_aggregator_mock.get_summury.return_value = f'{__name__}._DummyTask'
        event_aggregator_mock.get_event_list.return_value = f'{__name__}._DummyTask:[]'
        make_tree_info_mock.return_value = 'tree'

        def get_content(content: str, **kwargs: Any) -> None:
            # Capture the snippet content handed to the (mocked) Slack API.
            self.output = content

        slack_api_mock = MagicMock()
        slack_api_mock.send_snippet.side_effect = get_content

        cmdline_args = [f'{__name__}._DummyTask', '--param', 'test']
        with patch('gokart.slack.SlackConfig.send_tree_info', True):
            _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args)
        expects = os.linesep.join(['===== Event List ====', event_aggregator_mock.get_event_list(), os.linesep, '==== Tree Info ====', 'tree'])
        results = self.output
        self.assertEqual(expects, results)

        cmdline_args = [f'{__name__}._DummyTask', '--param', 'test']
        with patch('gokart.slack.SlackConfig.send_tree_info', False):
            _try_to_send_event_summary_to_slack(slack_api_mock, event_aggregator_mock, cmdline_args)
        expects = os.linesep.join(
            [
                '===== Event List ====',
                event_aggregator_mock.get_event_list(),
                os.linesep,
                '==== Tree Info ====',
                'Please add SlackConfig.send_tree_info to include tree-info',
            ]
        )
        results = self.output
        self.assertEqual(expects, results)
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()
================================================
FILE: test/test_s3_config.py
================================================
import unittest
from gokart.s3_config import S3Config
class TestS3Config(unittest.TestCase):
    """S3Config().get_s3_client() should hand back the same client on repeated calls."""

    def test_get_same_s3_client(self):
        first = S3Config().get_s3_client()
        second = S3Config().get_s3_client()
        self.assertEqual(first, second)
================================================
FILE: test/test_s3_zip_client.py
================================================
import os
import shutil
import unittest
import boto3
from moto import mock_aws
from gokart.s3_zip_client import S3ZipClient
from test.util import _get_temporary_directory
class TestS3ZipClient(unittest.TestCase):
    """Archive round-trips for S3ZipClient against a moto-mocked S3 bucket."""

    def setUp(self):
        self.temporary_directory = _get_temporary_directory()

    def tearDown(self):
        shutil.rmtree(self.temporary_directory, ignore_errors=True)
        # remove temporary zip archive if exists.
        if os.path.exists(f'{self.temporary_directory}.zip'):
            os.remove(f'{self.temporary_directory}.zip')

    @mock_aws
    def test_make_archive(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        file_path = os.path.join('s3://test/', 'test.zip')
        temporary_directory = self.temporary_directory
        zip_client = S3ZipClient(file_path=file_path, temporary_directory=temporary_directory)
        # raise error if temporary directory does not exist.
        with self.assertRaises(FileNotFoundError):
            zip_client.make_archive()
        # run without error because temporary directory exists.
        os.makedirs(temporary_directory, exist_ok=True)
        zip_client.make_archive()

    @mock_aws
    def test_unpack_archive(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        file_path = os.path.join('s3://test/', 'test.zip')
        in_temporary_directory = os.path.join(self.temporary_directory, 'in', 'dummy')
        out_temporary_directory = os.path.join(self.temporary_directory, 'out', 'dummy')
        # make dummy zip file.
        os.makedirs(in_temporary_directory, exist_ok=True)
        in_zip_client = S3ZipClient(file_path=file_path, temporary_directory=in_temporary_directory)
        in_zip_client.make_archive()
        # load dummy zip file.
        out_zip_client = S3ZipClient(file_path=file_path, temporary_directory=out_temporary_directory)
        self.assertFalse(os.path.exists(out_temporary_directory))
        out_zip_client.unpack_archive()
        # FIX: verify the post-condition — the test previously ended right after
        # unpack_archive() without asserting that the output directory was created.
        self.assertTrue(os.path.exists(out_temporary_directory))
================================================
FILE: test/test_serializable_parameter.py
================================================
import json
import tempfile
from dataclasses import asdict, dataclass
from typing import Any
import luigi
import pytest
from luigi.cmdline_parser import CmdlineParser
from mypy import api
from gokart import SerializableParameter, TaskOnKart
from test.config import PYPROJECT_TOML
@dataclass(frozen=True)
class Config:
    """Immutable value object implementing the gokart serialization protocol."""

    foo: int
    bar: str

    def gokart_serialize(self) -> str:
        """Return a JSON encoding; field order follows declaration order (dicts are ordered in 3.7+)."""
        as_mapping = asdict(self)
        return json.dumps(as_mapping)

    @classmethod
    def gokart_deserialize(cls, s: str) -> 'Config':
        """Rebuild a Config from the JSON produced by gokart_serialize."""
        decoded = json.loads(s)
        return cls(**decoded)
class SerializableParameterWithOutDefault(TaskOnKart[Any]):
    """Task whose config parameter has no default, so it must be supplied by the caller."""

    task_namespace = __name__
    config: SerializableParameter[Config] = SerializableParameter(object_type=Config)

    def run(self):
        self.dump(self.config)
class SerializableParameterWithDefault(TaskOnKart[Any]):
    """Task whose config parameter falls back to Config(foo=1, bar='bar') when omitted."""

    task_namespace = __name__
    config: SerializableParameter[Config] = SerializableParameter(object_type=Config, default=Config(foo=1, bar='bar'))

    def run(self):
        self.dump(self.config)
class TestSerializableParameter:
    """Command-line parsing behavior of SerializableParameter (pytest-style test class)."""

    def test_default(self):
        """Omitting --config uses the parameter's declared default."""
        with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithDefault']) as cp:
            assert cp.get_task_obj().config == Config(foo=1, bar='bar')

    def test_parse_param(self):
        """A JSON string on the command line is deserialized into a Config."""
        with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', '{"foo": 100, "bar": "val"}']) as cp:
            assert cp.get_task_obj().config == Config(foo=100, bar='val')

    def test_missing_parameter(self):
        """No default and no --config raises MissingParameterException."""
        with pytest.raises(luigi.parameter.MissingParameterException):
            with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault']) as cp:
                cp.get_task_obj()

    def test_value_error(self):
        """A value that is not valid JSON for Config raises ValueError."""
        with pytest.raises(ValueError):
            with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config', 'Foo']) as cp:
                cp.get_task_obj()

    def test_expected_one_argument_error(self):
        """--config with no following value makes argparse exit."""
        with pytest.raises(SystemExit):
            with CmdlineParser.global_instance([f'{__name__}.SerializableParameterWithOutDefault', '--config']) as cp:
                cp.get_task_obj()

    def test_mypy(self):
        """check invalid object cannot used for SerializableParameter"""
        # The snippet is checked by mypy as a standalone file, hence column-0 content.
        test_code = """
import gokart

class InvalidClass:
    ...

gokart.SerializableParameter(object_type=InvalidClass)
"""
        with tempfile.NamedTemporaryFile(suffix='.py') as test_file:
            test_file.write(test_code.encode('utf-8'))
            test_file.flush()
            result = api.run(['--no-incremental', '--cache-dir=/dev/null', '--config-file', str(PYPROJECT_TOML), test_file.name])
            assert 'Value of type variable "S" of "SerializableParameter" cannot be "InvalidClass" [type-var]' in result[0]
================================================
FILE: test/test_target.py
================================================
import io
import os
import shutil
import unittest
from datetime import datetime
from unittest.mock import patch
import boto3
import numpy as np
import pandas as pd
from matplotlib import pyplot
from moto import mock_aws
from gokart.file_processor.base import _ChunkedLargeFileReader
from gokart.target import make_model_target, make_target
from test.util import _get_temporary_directory
class LocalTargetTest(unittest.TestCase):
    """Round-trip tests for make_target() on the local filesystem.

    The file extension of each target path selects the file processor
    (pickle, text, gzip, npz, png, csv, tsv, parquet, feather), so the
    literal extensions below are significant.
    """

    def setUp(self):
        self.temporary_directory = _get_temporary_directory()

    def tearDown(self):
        shutil.rmtree(self.temporary_directory, ignore_errors=True)

    def test_save_and_load_pickle_file(self):
        """Pickle round trip; also verifies the chunked reader is used on load."""
        obj = 1
        file_path = os.path.join(self.temporary_directory, 'test.pkl')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        with unittest.mock.patch('gokart.file_processor.base._ChunkedLargeFileReader', wraps=_ChunkedLargeFileReader) as monkey:
            loaded = target.load()
            monkey.assert_called()
        self.assertEqual(loaded, obj)

    def test_save_and_load_text_file(self):
        obj = 1
        file_path = os.path.join(self.temporary_directory, 'test.txt')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        self.assertEqual(loaded, [str(obj)], msg='should save an object as List[str].')

    def test_save_and_load_gzip(self):
        obj = 1
        file_path = os.path.join(self.temporary_directory, 'test.gz')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        self.assertEqual(loaded, [str(obj)], msg='should save an object as List[str].')

    def test_save_and_load_npz(self):
        obj = np.ones(shape=10, dtype=np.float32)
        file_path = os.path.join(self.temporary_directory, 'test.npz')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        np.testing.assert_almost_equal(obj, loaded)

    def test_save_and_load_figure(self):
        """Dump raw PNG bytes produced by matplotlib and check a non-trivial payload loads back."""
        figure_binary = io.BytesIO()
        pd.DataFrame(dict(x=range(10), y=range(10))).plot.scatter(x='x', y='y')
        pyplot.savefig(figure_binary)
        figure_binary.seek(0)
        file_path = os.path.join(self.temporary_directory, 'test.png')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(figure_binary.read())
        loaded = target.load()
        self.assertGreater(len(loaded), 1000)  # any binary

    def test_save_and_load_csv(self):
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
        file_path = os.path.join(self.temporary_directory, 'test.csv')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)

    def test_save_and_load_tsv(self):
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
        file_path = os.path.join(self.temporary_directory, 'test.tsv')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)

    def test_save_and_load_parquet(self):
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
        file_path = os.path.join(self.temporary_directory, 'test.parquet')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)

    def test_save_and_load_feather(self):
        """Feather round trip preserving a named index."""
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=pd.Index([33, 44], name='object_index'))
        file_path = os.path.join(self.temporary_directory, 'test.feather')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)

    def test_save_and_load_feather_without_store_index_in_feather(self):
        """With store_index_in_feather=False the index must already be a column."""
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]), index=pd.Index([33, 44], name='object_index')).reset_index()
        file_path = os.path.join(self.temporary_directory, 'test.feather')
        target = make_target(file_path=file_path, unique_id=None, store_index_in_feather=False)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)

    def test_last_modified_time(self):
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
        file_path = os.path.join(self.temporary_directory, 'test.csv')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        t = target.last_modification_time()
        self.assertIsInstance(t, datetime)

    def test_last_modified_time_without_file(self):
        file_path = os.path.join(self.temporary_directory, 'test.csv')
        target = make_target(file_path=file_path, unique_id=None)
        with self.assertRaises(FileNotFoundError):
            target.last_modification_time()

    def test_save_pandas_series(self):
        """A Series is stored via its name and loads back as a one-column frame."""
        obj = pd.Series(data=[1, 2], name='column_name')
        file_path = os.path.join(self.temporary_directory, 'test.csv')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_series_equal(loaded['column_name'], obj)

    def test_dump_with_lock(self):
        """lock_at_dump=True routes the dump through wrap_dump_with_lock."""
        with patch('gokart.target.wrap_dump_with_lock') as wrap_with_lock_mock:
            obj = 1
            file_path = os.path.join(self.temporary_directory, 'test.pkl')
            target = make_target(file_path=file_path, unique_id=None)
            target.dump(obj, lock_at_dump=True)
            wrap_with_lock_mock.assert_called_once()

    def test_dump_without_lock(self):
        """lock_at_dump=False must not touch the lock wrapper."""
        with patch('gokart.target.wrap_dump_with_lock') as wrap_with_lock_mock:
            obj = 1
            file_path = os.path.join(self.temporary_directory, 'test.pkl')
            target = make_target(file_path=file_path, unique_id=None)
            target.dump(obj, lock_at_dump=False)
            wrap_with_lock_mock.assert_not_called()
class S3TargetTest(unittest.TestCase):
    """Same target round-trip checks as LocalTargetTest, but against a moto-mocked S3 bucket."""

    @mock_aws
    def test_save_on_s3(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        obj = 1
        file_path = os.path.join('s3://test/', 'test.pkl')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        self.assertEqual(loaded, obj)

    @mock_aws
    def test_last_modified_time(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        obj = 1
        file_path = os.path.join('s3://test/', 'test.pkl')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        t = target.last_modification_time()
        self.assertIsInstance(t, datetime)

    @mock_aws
    def test_last_modified_time_without_file(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        file_path = os.path.join('s3://test/', 'test.pkl')
        target = make_target(file_path=file_path, unique_id=None)
        with self.assertRaises(FileNotFoundError):
            target.last_modification_time()

    @mock_aws
    def test_save_on_s3_feather(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
        file_path = os.path.join('s3://test/', 'test.feather')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)

    @mock_aws
    def test_save_on_s3_parquet(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        obj = pd.DataFrame(dict(a=[1, 2], b=[3, 4]))
        file_path = os.path.join('s3://test/', 'test.parquet')
        target = make_target(file_path=file_path, unique_id=None)
        target.dump(obj)
        loaded = target.load()
        pd.testing.assert_frame_equal(loaded, obj)
class ModelTargetTest(unittest.TestCase):
    """make_model_target() round trips via user-supplied save/load functions, locally and on S3."""

    def setUp(self):
        self.temporary_directory = _get_temporary_directory()

    def tearDown(self):
        shutil.rmtree(self.temporary_directory, ignore_errors=True)

    @staticmethod
    def _save_function(obj, path):
        # Delegate persistence to an ordinary file target at `path`.
        make_target(file_path=path).dump(obj)

    @staticmethod
    def _load_function(path):
        return make_target(file_path=path).load()

    def test_model_target_on_local(self):
        obj = 1
        file_path = os.path.join(self.temporary_directory, 'test.zip')
        target = make_model_target(
            file_path=file_path, temporary_directory=self.temporary_directory, save_function=self._save_function, load_function=self._load_function
        )
        target.dump(obj)
        loaded = target.load()
        self.assertEqual(loaded, obj)

    @mock_aws
    def test_model_target_on_s3(self):
        conn = boto3.resource('s3', region_name='us-east-1')
        conn.create_bucket(Bucket='test')
        obj = 1
        file_path = os.path.join('s3://test/', 'test.zip')
        target = make_model_target(
            file_path=file_path, temporary_directory=self.temporary_directory, save_function=self._save_function, load_function=self._load_function
        )
        target.dump(obj)
        loaded = target.load()
        self.assertEqual(loaded, obj)
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()
================================================
FILE: test/test_task_instance_parameter.py
================================================
import unittest
from typing import Any
import luigi
import gokart
from gokart import ListTaskInstanceParameter, TaskInstanceParameter, TaskOnKart
class _DummySubTask(TaskOnKart[Any]):
    """Base dummy task; also the expected_type used in the parameter-type tests below."""

    task_namespace = __name__
    pass
class _DummyCorrectSubClassTask(_DummySubTask):
    """Subclass of _DummySubTask — should be accepted where _DummySubTask is expected."""

    task_namespace = __name__
    pass
class _DummyInvalidSubClassTask(TaskOnKart[Any]):
    """NOT a subclass of _DummySubTask — should be rejected where _DummySubTask is expected."""

    task_namespace = __name__
    pass
class _DummyTask(TaskOnKart[Any]):
    """Task carrying an int parameter and a single task-instance parameter."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter()
    task: TaskInstanceParameter[_DummySubTask] = TaskInstanceParameter(default=_DummySubTask())
class _DummyListTask(TaskOnKart[Any]):
    """Task carrying an int parameter and a list of task instances."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter()
    task: ListTaskInstanceParameter[_DummySubTask] = ListTaskInstanceParameter(default=[_DummySubTask(), _DummySubTask()])
class TaskInstanceParameterTest(unittest.TestCase):
    """Serialization round trips and expected_type enforcement for TaskInstanceParameter."""

    def setUp(self):
        # Drop cached instances so each test builds fresh task objects.
        _DummyTask.clear_instance_cache()

    def test_serialize_and_parse(self):
        original = _DummyTask(param=2)
        s = gokart.TaskInstanceParameter().serialize(original)
        parsed = gokart.TaskInstanceParameter().parse(s)
        self.assertEqual(parsed.task_id, original.task_id)

    def test_serialize_and_parse_list_params(self):
        """A task whose own parameter is a list of tasks must also round trip."""
        original = _DummyListTask(param=2)
        s = gokart.TaskInstanceParameter().serialize(original)
        parsed = gokart.TaskInstanceParameter().parse(s)
        self.assertEqual(parsed.task_id, original.task_id)

    def test_invalid_class(self):
        # expected_type must be a type, not an arbitrary value.
        self.assertRaises(TypeError, lambda: gokart.TaskInstanceParameter(expected_type=1))  # type: ignore

    def test_params_with_correct_param_type(self):
        class _DummyPipelineA(TaskOnKart[Any]):
            task_namespace = __name__
            subtask: gokart.TaskInstanceParameter[_DummySubTask] = gokart.TaskInstanceParameter(expected_type=_DummySubTask)

        # A subclass of the expected type is accepted.
        task = _DummyPipelineA(subtask=_DummyCorrectSubClassTask())
        self.assertEqual(task.requires()['subtask'], _DummyCorrectSubClassTask())  # type: ignore

    def test_params_with_invalid_param_type(self):
        class _DummyPipelineB(TaskOnKart[Any]):
            task_namespace = __name__
            subtask: gokart.TaskInstanceParameter[_DummySubTask] = gokart.TaskInstanceParameter(expected_type=_DummySubTask)

        # A task outside the expected hierarchy is rejected at construction.
        with self.assertRaises(TypeError):
            _DummyPipelineB(subtask=_DummyInvalidSubClassTask())  # type: ignore
class ListTaskInstanceParameterTest(unittest.TestCase):
    """expected_elements_type enforcement for ListTaskInstanceParameter."""

    def setUp(self):
        _DummyTask.clear_instance_cache()

    def test_invalid_class(self):
        # expected_elements_type must be a type, not an arbitrary value.
        self.assertRaises(TypeError, lambda: gokart.ListTaskInstanceParameter(expected_elements_type=1))  # type: ignore # not type instance

    def test_list_params_with_correct_param_types(self):
        class _DummyPipelineC(TaskOnKart[Any]):
            task_namespace = __name__
            subtask: gokart.ListTaskInstanceParameter[_DummySubTask] = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

        # Elements that subclass the expected type are accepted (stored as a tuple).
        task = _DummyPipelineC(subtask=[_DummyCorrectSubClassTask()])
        self.assertEqual(task.requires()['subtask'], (_DummyCorrectSubClassTask(),))  # type: ignore

    def test_list_params_with_invalid_param_types(self):
        class _DummyPipelineD(TaskOnKart[Any]):
            task_namespace = __name__
            subtask: gokart.ListTaskInstanceParameter[_DummySubTask] = gokart.ListTaskInstanceParameter(expected_elements_type=_DummySubTask)

        # One invalid element is enough to reject the whole list.
        with self.assertRaises(TypeError):
            _DummyPipelineD(subtask=[_DummyInvalidSubClassTask(), _DummyCorrectSubClassTask()])  # type: ignore
if __name__ == '__main__':
    # Allow running this test module directly.
    unittest.main()
================================================
FILE: test/test_task_on_kart.py
================================================
from __future__ import annotations
import os
import pathlib
import unittest
from datetime import datetime
from typing import Any, cast
from unittest.mock import Mock, patch
import luigi
import pandas as pd
from luigi.parameter import ParameterVisibility
from luigi.util import inherits
import gokart
from gokart.file_processor import XmlFileProcessor
from gokart.parameter import ListTaskInstanceParameter, TaskInstanceParameter
from gokart.target import ModelTarget, SingleFileTarget, TargetOnKart
from gokart.task import EmptyDumpError
class _DummyTask(gokart.TaskOnKart[Any]):
    """Task with int/list/bool parameters and no output target (output() returns None)."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter(default=1)
    list_param: luigi.ListParameter[tuple[str, ...]] = luigi.ListParameter(default=('a', 'b'))
    bool_param: luigi.BoolParameter = luigi.BoolParameter()

    def output(self):
        # No output: completeness then depends only on flags/inputs in the tests below.
        return None
class _DummyTaskA(gokart.TaskOnKart[Any]):
    """Leaf task with no output, used as the root of the clone/inherits chain."""

    task_namespace = __name__

    def output(self):
        return None
@inherits(_DummyTaskA)
class _DummyTaskB(gokart.TaskOnKart[Any]):
    """Inherits _DummyTaskA's parameters and requires a cloned _DummyTaskA."""

    task_namespace = __name__

    def output(self):
        return None

    def requires(self):
        return self.clone(_DummyTaskA)
@inherits(_DummyTaskB)
class _DummyTaskC(gokart.TaskOnKart[Any]):
    """Second link in the inherits chain; requires a cloned _DummyTaskB."""

    task_namespace = __name__

    def output(self):
        return None

    def requires(self):
        return self.clone(_DummyTaskB)
class _DummyTaskD(gokart.TaskOnKart[Any]):
    """Task with no output() override — exercises TaskOnKart's default target."""

    task_namespace = __name__
class _DummyTaskWithoutLock(gokart.TaskOnKart[Any]):
    """Task whose run() does nothing; used for lock-related tests."""

    task_namespace = __name__

    def run(self):
        pass
class _DummySubTaskWithPrivateParameter(gokart.TaskOnKart[Any]):
    """Bare sub-task used as a value for the task-instance parameters below."""

    task_namespace = __name__
class _DummyTaskWithPrivateParameter(gokart.TaskOnKart[Any]):
    """Task mixing public, private (ParameterVisibility.PRIVATE) and task-instance parameters."""

    task_namespace = __name__
    int_param: luigi.IntParameter = luigi.IntParameter()
    private_int_param: luigi.IntParameter = luigi.IntParameter(visibility=ParameterVisibility.PRIVATE)
    task_param: TaskInstanceParameter[Any] = TaskInstanceParameter()
    list_task_param: ListTaskInstanceParameter[Any] = ListTaskInstanceParameter()
class TaskTest(unittest.TestCase):
def setUp(self):
    # Drop cached instances for every dummy task so tests don't share state.
    _DummyTask.clear_instance_cache()
    _DummyTaskA.clear_instance_cache()
    _DummyTaskB.clear_instance_cache()
    _DummyTaskC.clear_instance_cache()
def test_complete_without_dependency(self):
    """A task with no output files is always considered complete."""
    task = _DummyTask()
    self.assertTrue(task.complete(), msg='_DummyTask does not have any output files, so this always must be completed.')
def test_complete_with_rerun_flag(self):
    """rerun=True forces exactly one incomplete report, then reverts."""
    task = _DummyTask(rerun=True)
    self.assertFalse(task.complete(), msg='"rerun" flag force tasks rerun once.')
    self.assertTrue(task.complete(), msg='"rerun" flag should be changed.')
def test_complete_with_uncompleted_input(self):
    """strict_check makes complete() take the existence of input targets into account."""
    missing_input = Mock(spec=TargetOnKart)
    missing_input.exists.return_value = False
    # depends on an uncompleted target.
    task = _DummyTask()
    task.input = Mock(return_value=missing_input)  # type: ignore
    # By default, inputs are ignored when judging completeness.
    self.assertTrue(task.complete(), msg='task does not care input targets.')
    # make a task check its inputs.
    task.strict_check = True
    self.assertFalse(task.complete())
def test_complete_with_modified_input(self):
    """modification_time_check makes a task incomplete when an input is newer than its output."""
    input_target = Mock(spec=TargetOnKart)
    input_target.exists.return_value = True
    input_target.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0)
    output_target = Mock(spec=TargetOnKart)
    output_target.exists.return_value = True
    # Output is one hour older than the input.
    output_target.last_modification_time.return_value = datetime(2018, 1, 1, 9, 0, 0)
    # depends on an uncompleted target.
    task = _DummyTask()
    task.modification_time_check = False
    task.input = Mock(return_value=input_target)  # type: ignore
    task.output = Mock(return_value=output_target)  # type: ignore
    self.assertTrue(task.complete(), msg='task does not care modified time')
    # make a task check its inputs.
    task.modification_time_check = True
    self.assertFalse(task.complete())
def test_complete_when_modification_time_equals_output(self):
"""Test the case that modification time of input equals that of output.
The case is occurred when input and output targets are same.
"""
input_target = Mock(spec=TargetOnKart)
input_target.exists.return_value = True
input_target.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0)
output_target = Mock(spec=TargetOnKart)
output_target.exists.return_value = True
output_target.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0)
task = _DummyTask()
task.modification_time_check = True
task.input = Mock(return_value=input_target) # type: ignore
task.output = Mock(return_value=output_target) # type: ignore
self.assertTrue(task.complete())
def test_complete_when_input_and_output_equal(self):
target1 = Mock(spec=TargetOnKart)
target1.exists.return_value = True
target1.path.return_value = 'path1.pkl'
target1.last_modification_time.return_value = datetime(2018, 1, 1, 10, 0, 0)
target2 = Mock(spec=TargetOnKart)
target2.exists.return_value = True
target2.path.return_value = 'path2.pkl'
target2.last_modification_time.return_value = datetime(2018, 1, 1, 9, 0, 0)
target3 = Mock(spec=TargetOnKart)
target3.exists.return_value = True
target3.path.return_value = 'path3.pkl'
target3.last_modification_time.return_value = datetime(2018, 1, 1, 9, 0, 0)
task = _DummyTask()
task.modification_time_check = True
task.input = Mock(return_value=[target1, target2]) # type: ignore
task.output = Mock(return_value=[target1, target2]) # type: ignore
self.assertTrue(task.complete())
task.input = Mock(return_value=[target1, target2]) # type: ignore
task.output = Mock(return_value=[target2, target3]) # type: ignore
self.assertFalse(task.complete())
def test_default_target(self):
task = _DummyTaskD()
default_target = task.output()
self.assertIsInstance(default_target, SingleFileTarget)
self.assertEqual(f'_DummyTaskD_{task.task_unique_id}.pkl', pathlib.Path(default_target._target.path).name) # type: ignore
def test_clone_with_special_params(self):
class _DummyTaskRerun(gokart.TaskOnKart[Any]):
a: luigi.BoolParameter = luigi.BoolParameter(default=False)
task = _DummyTaskRerun(a=True, rerun=True)
cloned = task.clone(_DummyTaskRerun)
cloned_with_explicit_rerun = task.clone(_DummyTaskRerun, rerun=True)
self.assertTrue(cloned.a)
self.assertFalse(cloned.rerun) # do not clone rerun
self.assertTrue(cloned_with_explicit_rerun.a)
self.assertTrue(cloned_with_explicit_rerun.rerun)
def test_default_large_dataframe_target(self):
task = _DummyTaskD()
default_large_dataframe_target = task.make_large_data_frame_target()
self.assertIsInstance(default_large_dataframe_target, ModelTarget)
target = cast(ModelTarget, default_large_dataframe_target)
self.assertEqual(f'_DummyTaskD_{task.task_unique_id}.zip', pathlib.Path(target._zip_client.path).name)
def test_make_target(self):
task = _DummyTask()
target = task.make_target('test.txt')
self.assertIsInstance(target, SingleFileTarget)
def test_make_target_without_id(self):
path = _DummyTask().make_target('test.txt', use_unique_id=False).path()
self.assertEqual(path, os.path.join(_DummyTask().workspace_directory, 'test.txt'))
def test_make_target_with_processor(self):
task = _DummyTask()
processor = XmlFileProcessor()
target = task.make_target('test.dummy', processor=processor)
self.assertIsInstance(target, SingleFileTarget)
target = cast(SingleFileTarget, target)
self.assertEqual(target._processor, processor)
def test_get_own_code(self):
task = _DummyTask()
task_scripts = 'def output(self):\nreturn None\n'
self.assertEqual(task.get_own_code().replace(' ', ''), task_scripts.replace(' ', ''))
def test_make_unique_id_with_own_code(self):
class _MyDummyTaskA(gokart.TaskOnKart[str]):
_visible_in_registry = False
def run(self):
self.dump('Hello, world!')
task_unique_id = _MyDummyTaskA(serialized_task_definition_check=False).make_unique_id()
task_with_code_unique_id = _MyDummyTaskA(serialized_task_definition_check=True).make_unique_id()
self.assertNotEqual(task_unique_id, task_with_code_unique_id)
class _MyDummyTaskA(gokart.TaskOnKart[str]): # type: ignore
_visible_in_registry = False
def run(self):
modified_code = 'modify!!'
self.dump(modified_code)
task_modified_unique_id = _MyDummyTaskA(serialized_task_definition_check=False).make_unique_id()
task_modified_with_code_unique_id = _MyDummyTaskA(serialized_task_definition_check=True).make_unique_id()
self.assertEqual(task_modified_unique_id, task_unique_id)
self.assertNotEqual(task_modified_with_code_unique_id, task_with_code_unique_id)
def test_compare_targets_of_different_tasks(self):
path1 = _DummyTask(param=1).make_target('test.txt').path()
path2 = _DummyTask(param=2).make_target('test.txt').path()
self.assertNotEqual(path1, path2, msg='different tasks must generate different targets.')
def test_make_model_target(self):
task = _DummyTask()
target = task.make_model_target('test.zip', save_function=Mock(), load_function=Mock())
self.assertIsInstance(target, ModelTarget)
def test_load_with_single_target(self):
task = _DummyTask()
target = Mock(spec=TargetOnKart)
target.load.return_value = 1
task.input = Mock(return_value=target) # type: ignore
data = task.load()
target.load.assert_called_once()
self.assertEqual(data, 1)
def test_load_with_single_dict_target(self):
task = _DummyTask()
target = Mock(spec=TargetOnKart)
target.load.return_value = 1
task.input = Mock(return_value={'target_key': target}) # type: ignore
data = task.load()
target.load.assert_called_once()
self.assertEqual(data, {'target_key': 1})
def test_load_with_keyword(self):
task = _DummyTask()
target = Mock(spec=TargetOnKart)
target.load.return_value = 1
task.input = Mock(return_value={'target_key': target}) # type: ignore
data = task.load('target_key')
target.load.assert_called_once()
self.assertEqual(data, 1)
def test_load_tuple(self):
task = _DummyTask()
target1 = Mock(spec=TargetOnKart)
target1.load.return_value = 1
target2 = Mock(spec=TargetOnKart)
target2.load.return_value = 2
task.input = Mock(return_value=(target1, target2)) # type: ignore
data = task.load()
target1.load.assert_called_once()
target2.load.assert_called_once()
self.assertEqual(data[0], 1)
self.assertEqual(data[1], 2)
def test_load_dictionary_at_once(self):
task = _DummyTask()
target1 = Mock(spec=TargetOnKart)
target1.load.return_value = 1
target2 = Mock(spec=TargetOnKart)
target2.load.return_value = 2
task.input = Mock(return_value={'target_key_1': target1, 'target_key_2': target2}) # type: ignore
data = task.load()
target1.load.assert_called_once()
target2.load.assert_called_once()
self.assertEqual(data['target_key_1'], 1)
self.assertEqual(data['target_key_2'], 2)
def test_load_with_task_on_kart(self):
task = _DummyTask()
task2 = Mock(spec=gokart.TaskOnKart)
task2.make_unique_id.return_value = 'task2'
task2_output = Mock(spec=TargetOnKart)
task2.output.return_value = task2_output
task2_output.load.return_value = 1
# task2 should be in requires' return values
task.requires = lambda: {'task2': task2} # type: ignore
actual = task.load(task2)
self.assertEqual(actual, 1)
def test_load_with_task_on_kart_should_fail_when_task_on_kart_is_not_in_requires(self):
"""
if load args is not in requires, it should raise an error.
"""
task = _DummyTask()
task2 = Mock(spec=gokart.TaskOnKart)
task2_output = Mock(spec=TargetOnKart)
task2.output.return_value = task2_output
task2_output.load.return_value = 1
with self.assertRaises(AssertionError):
task.load(task2)
def test_load_with_task_on_kart_list(self):
task = _DummyTask()
task2 = Mock(spec=gokart.TaskOnKart)
task2.make_unique_id.return_value = 'task2'
task2_output = Mock(spec=TargetOnKart)
task2.output.return_value = task2_output
task2_output.load.return_value = 1
task3 = Mock(spec=gokart.TaskOnKart)
task3.make_unique_id.return_value = 'task3'
task3_output = Mock(spec=TargetOnKart)
task3.output.return_value = task3_output
task3_output.load.return_value = 2
# task2 should be in requires' return values
task.requires = lambda: {'tasks': [task2, task3]} # type: ignore
load_args: list[gokart.TaskOnKart[int]] = [task2, task3]
actual = task.load(load_args)
self.assertEqual(actual, [1, 2])
def test_load_generator_with_single_target(self):
task = _DummyTask()
target = Mock(spec=TargetOnKart)
target.load.return_value = [1, 2]
task.input = Mock(return_value=target) # type: ignore
data = [x for x in task.load_generator()]
self.assertEqual(data, [[1, 2]])
def test_load_generator_with_keyword(self):
task = _DummyTask()
target = Mock(spec=TargetOnKart)
target.load.return_value = [1, 2]
task.input = Mock(return_value={'target_key': target}) # type: ignore
data = [x for x in task.load_generator('target_key')]
self.assertEqual(data, [[1, 2]])
def test_load_generator_with_list_task_on_kart(self):
task = _DummyTask()
task2 = Mock(spec=gokart.TaskOnKart)
task2.make_unique_id.return_value = 'task2'
task2_output = Mock(spec=TargetOnKart)
task2.output.return_value = task2_output
task2_output.load.return_value = 1
task3 = Mock(spec=gokart.TaskOnKart)
task3.make_unique_id.return_value = 'task3'
task3_output = Mock(spec=TargetOnKart)
task3.output.return_value = task3_output
task3_output.load.return_value = 2
# task2 should be in requires' return values
task.requires = lambda: {'tasks': [task2, task3]} # type: ignore
load_args: list[gokart.TaskOnKart[int]] = [task2, task3]
actual = [x for x in task.load_generator(load_args)]
self.assertEqual(actual, [1, 2])
def test_dump(self):
task = _DummyTask()
target = Mock(spec=TargetOnKart)
task.output = Mock(return_value=target) # type: ignore
task.dump(1)
target.dump.assert_called_once()
def test_fail_on_empty_dump(self):
# do not fail
task = _DummyTask(fail_on_empty_dump=False)
target = Mock(spec=TargetOnKart)
task.output = Mock(return_value=target) # type: ignore
task.dump(pd.DataFrame())
target.dump.assert_called_once()
# fail
task = _DummyTask(fail_on_empty_dump=True)
self.assertRaises(EmptyDumpError, lambda: task.dump(pd.DataFrame()))
@patch('luigi.configuration.get_config')
def test_add_configuration(self, mock_config: Mock) -> None:
mock_config.return_value = {'_DummyTask': {'list_param': '["c", "d"]', 'param': '3', 'bool_param': 'True'}}
kwargs: dict[str, Any] = dict()
_DummyTask._add_configuration(kwargs, '_DummyTask')
self.assertEqual(3, kwargs['param'])
self.assertEqual(['c', 'd'], list(kwargs['list_param']))
self.assertEqual(True, kwargs['bool_param'])
@patch('luigi.cmdline_parser.CmdlineParser.get_instance')
def test_add_cofigureation_evaluation_order(self, mock_cmdline: Mock) -> None:
"""
in case TaskOnKart._add_configuration will break evaluation order
@see https://luigi.readthedocs.io/en/stable/parameters.html#parameter-resolution-order
"""
class DummyTaskAddConfiguration(gokart.TaskOnKart[Any]):
aa = luigi.IntParameter()
luigi.configuration.get_config().set('DummyTaskAddConfiguration', 'aa', '3')
mock_cmdline.return_value = luigi.cmdline_parser.CmdlineParser(['DummyTaskAddConfiguration'])
self.assertEqual(DummyTaskAddConfiguration().aa, 3)
mock_cmdline.return_value = luigi.cmdline_parser.CmdlineParser(['DummyTaskAddConfiguration', '--DummyTaskAddConfiguration-aa', '2'])
self.assertEqual(DummyTaskAddConfiguration().aa, 2)
def test_use_rerun_with_inherits(self):
# All tasks are completed.
task_c = _DummyTaskC()
self.assertTrue(task_c.complete())
self.assertTrue(task_c.requires().complete()) # This is an instance of TaskB.
self.assertTrue(task_c.requires().requires().complete()) # This is an instance of TaskA.
luigi.configuration.get_config().set(f'{__name__}._DummyTaskB', 'rerun', 'True')
task_c = _DummyTaskC()
self.assertTrue(task_c.complete())
self.assertFalse(task_c.requires().complete()) # This is an instance of _DummyTaskB.
self.assertTrue(task_c.requires().requires().complete()) # This is an instance of _DummyTaskA.
# All tasks are not completed, because _DummyTaskC.rerun = True.
task_c = _DummyTaskC(rerun=True)
self.assertFalse(task_c.complete())
self.assertTrue(task_c.requires().complete()) # This is an instance of _DummyTaskB.
self.assertTrue(task_c.requires().requires().complete()) # This is an instance of _DummyTaskA.
def test_significant_flag(self) -> None:
def _make_task(significant: bool, has_required_task: bool) -> gokart.TaskOnKart[Any]:
class _MyDummyTaskA(gokart.TaskOnKart[Any]):
task_namespace = f'{__name__}_{significant}_{has_required_task}'
class _MyDummyTaskB(gokart.TaskOnKart[Any]):
task_namespace = f'{__name__}_{significant}_{has_required_task}'
def requires(self):
if has_required_task:
return _MyDummyTaskA(significant=significant)
return
return _MyDummyTaskB()
x_task = _make_task(significant=True, has_required_task=True)
y_task = _make_task(significant=False, has_required_task=True)
z_task = _make_task(significant=False, has_required_task=False)
self.assertNotEqual(x_task.make_unique_id(), y_task.make_unique_id())
self.assertEqual(y_task.make_unique_id(), z_task.make_unique_id())
def test_default_requires(self):
class _WithoutTaskInstanceParameter(gokart.TaskOnKart[Any]):
task_namespace = __name__
class _WithTaskInstanceParameter(gokart.TaskOnKart[Any]):
task_namespace = __name__
a_task: gokart.TaskInstanceParameter[Any] = gokart.TaskInstanceParameter()
without_task = _WithoutTaskInstanceParameter()
self.assertListEqual(without_task.requires(), []) # type: ignore
with_task = _WithTaskInstanceParameter(a_task=without_task)
self.assertEqual(with_task.requires()['a_task'], without_task) # type: ignore
def test_repr(self):
task = _DummyTaskWithPrivateParameter(
int_param=1,
private_int_param=1,
task_param=_DummySubTaskWithPrivateParameter(),
list_task_param=[_DummySubTaskWithPrivateParameter(), _DummySubTaskWithPrivateParameter()],
)
task_id = task.make_unique_id()
sub_task_id = _DummySubTaskWithPrivateParameter().make_unique_id()
expected = (
f'{__name__}._DummyTaskWithPrivateParameter[{task_id}](int_param=1, private_int_param=1, task_param={__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), '
f'list_task_param=[{__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), {__name__}._DummySubTaskWithPrivateParameter({sub_task_id})])'
) # noqa:E501
self.assertEqual(expected, repr(task))
def test_str(self):
task = _DummyTaskWithPrivateParameter(
int_param=1,
private_int_param=1,
task_param=_DummySubTaskWithPrivateParameter(),
list_task_param=[_DummySubTaskWithPrivateParameter(), _DummySubTaskWithPrivateParameter()],
)
task_id = task.make_unique_id()
sub_task_id = _DummySubTaskWithPrivateParameter().make_unique_id()
expected = (
f'{__name__}._DummyTaskWithPrivateParameter[{task_id}](int_param=1, task_param={__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), '
f'list_task_param=[{__name__}._DummySubTaskWithPrivateParameter({sub_task_id}), {__name__}._DummySubTaskWithPrivateParameter({sub_task_id})])'
)
self.assertEqual(expected, str(task))
def test_is_task_on_kart(self):
self.assertEqual(True, gokart.TaskOnKart.is_task_on_kart(gokart.TaskOnKart()))
self.assertEqual(False, gokart.TaskOnKart.is_task_on_kart(1))
self.assertEqual(False, gokart.TaskOnKart.is_task_on_kart(list()))
self.assertEqual(True, gokart.TaskOnKart.is_task_on_kart((gokart.TaskOnKart(), gokart.TaskOnKart())))
def test_serialize_and_deserialize_default_values(self):
task: gokart.TaskOnKart[Any] = gokart.TaskOnKart()
deserialized: gokart.TaskOnKart[Any] = luigi.task_register.load_task(None, task.get_task_family(), task.to_str_params())
self.assertDictEqual(task.to_str_params(), deserialized.to_str_params())
def test_to_str_params_changes_on_values_and_flags(self):
class _DummyTaskWithParams(gokart.TaskOnKart[Any]):
task_namespace = __name__
param: luigi.Parameter = luigi.Parameter()
t1 = _DummyTaskWithParams(param='a')
self.assertEqual(t1.to_str_params(), t1.to_str_params()) # cache
self.assertEqual(t1.to_str_params(), _DummyTaskWithParams(param='a').to_str_params()) # same value
self.assertNotEqual(t1.to_str_params(), _DummyTaskWithParams(param='b').to_str_params()) # different value
self.assertNotEqual(t1.to_str_params(), t1.to_str_params(only_significant=True))
def test_should_lock_run_when_set(self):
class _DummyTaskWithLock(gokart.TaskOnKart[str]):
def run(self):
self.dump('hello')
task = _DummyTaskWithLock(redis_host='host', redis_port=123, redis_timeout=180, should_lock_run=True)
self.assertEqual(task.run.__wrapped__.__name__, 'run') # type: ignore
def test_should_fail_lock_run_when_host_unset(self):
with self.assertRaises(AssertionError):
gokart.TaskOnKart(redis_port=123, redis_timeout=180, should_lock_run=True)
def test_should_fail_lock_run_when_port_unset(self):
with self.assertRaises(AssertionError):
gokart.TaskOnKart(redis_host='host', redis_timeout=180, should_lock_run=True)
class _DummyTaskWithNonCompleted(gokart.TaskOnKart[Any]):
    """Dummy task that always reports itself as NOT completed."""

    def dump(self, _obj: Any, _target: Any = None, _custom_labels: Any = None) -> None:
        # Override dump() as a no-op so tests never touch real storage.
        pass

    def complete(self):
        # Always pretend the output does not exist yet.
        return False

    def run(self):
        self.dump('hello')
class _DummyTaskWithCompleted(gokart.TaskOnKart[Any]):
    """Dummy task that always reports itself as completed."""

    def dump(self, obj: Any, _target: Any = None, custom_labels: Any = None) -> None:
        # Override dump() as a no-op so tests never touch real storage.
        pass

    def complete(self):
        # Always pretend the output already exists.
        return True

    def run(self):
        self.dump('hello')
class TestCompleteCheckAtRun(unittest.TestCase):
    """Verify how the ``complete_check_at_run`` flag interacts with ``run()``.

    ``dump`` is replaced with a mock so each test can observe whether the
    task body was actually executed.
    """

    @staticmethod
    def _run_with_mocked_dump(task):
        # Swap dump() for a mock, invoke run(), and hand the mock back for inspection.
        task.dump = Mock()  # type: ignore
        task.run()
        return task.dump

    def test_run_when_complete_check_at_run_is_false_and_task_is_not_completed(self):
        dump_spy = self._run_with_mocked_dump(_DummyTaskWithNonCompleted(complete_check_at_run=False))
        # run() executed the body, so dump() must have been called.
        dump_spy.assert_called_once()

    def test_run_when_complete_check_at_run_is_false_and_task_is_completed(self):
        dump_spy = self._run_with_mocked_dump(_DummyTaskWithCompleted(complete_check_at_run=False))
        # The check is disabled, so dump() is called even though the task is completed.
        dump_spy.assert_called_once()

    def test_run_when_complete_check_at_run_is_true_and_task_is_not_completed(self):
        dump_spy = self._run_with_mocked_dump(_DummyTaskWithNonCompleted(complete_check_at_run=True))
        # The task is not completed, so run() proceeds and dump() is called.
        dump_spy.assert_called_once()

    def test_run_when_complete_check_at_run_is_true_and_task_is_completed(self):
        dump_spy = self._run_with_mocked_dump(_DummyTaskWithCompleted(complete_check_at_run=True))
        # A completed task short-circuits run(), so dump() must not be called.
        dump_spy.assert_not_called()
# Allow running this test module directly with ``python <file>``.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/test_utils.py
================================================
import unittest
from typing import TYPE_CHECKING

import pandas as pd
import pytest

from gokart.task import TaskOnKart
from gokart.utils import flatten, get_dataframe_type_from_task, map_flattenable_items

if TYPE_CHECKING:
    import polars as pl

# polars is an optional dependency: probe for it at import time so the
# polars-specific tests below can be skipped when it is not installed.
try:
    import polars as pl

    HAS_POLARS = True
except ImportError:
    HAS_POLARS = False
class TestFlatten(unittest.TestCase):
    """Behavioral tests for ``gokart.utils.flatten``."""

    def test_flatten_dict(self):
        # Dict values are flattened; keys are dropped.
        flattened = flatten({'a': 'foo', 'b': 'bar'})
        self.assertEqual(flattened, ['foo', 'bar'])

    def test_flatten_list(self):
        # Nested lists collapse into one flat list.
        flattened = flatten(['foo', ['bar', 'troll']])
        self.assertEqual(flattened, ['foo', 'bar', 'troll'])

    def test_flatten_str(self):
        # A bare string is treated as a scalar, not as an iterable of characters.
        self.assertEqual(flatten('foo'), ['foo'])

    def test_flatten_int(self):
        # A scalar is wrapped into a single-element list.
        self.assertEqual(flatten(42), [42])

    def test_flatten_none(self):
        # None flattens to an empty list.
        self.assertEqual(flatten(None), [])
class TestMapFlatten(unittest.TestCase):
    """Behavioral tests for ``gokart.utils.map_flattenable_items``."""

    def test_map_flattenable_items(self):
        # Mapping over a dict applies the function to the values only.
        self.assertEqual(map_flattenable_items(str, {'a': 1, 'b': 2}), {'a': '1', 'b': '2'})

        # Deeply nested tuples (with an embedded dict) keep their structure.
        nested_input = (1, 2, 3, (4, 5, (6, 7, {'a': (8, 9, 0)})))
        nested_expected = ('1', '2', '3', ('4', '5', ('6', '7', {'a': ('8', '9', '0')})))
        self.assertEqual(map_flattenable_items(str, nested_input), nested_expected)

        # Mixed dict/list nesting is mapped element-wise as well.
        mixed_input = {'a': [1, 2, 3, '4'], 'b': {'c': True, 'd': {'e': 5}}}
        mixed_expected = {'a': ['1', '2', '3', '4'], 'b': {'c': 'True', 'd': {'e': '5'}}}
        self.assertEqual(map_flattenable_items(str, mixed_input), mixed_expected)
class TestGetDataFrameTypeFromTask(unittest.TestCase):
    """Tests for the get_dataframe_type_from_task function.

    The function inspects a task's generic type parameter (via
    ``__orig_bases__`` / MRO as exercised below) and returns one of
    'pandas', 'polars', or 'polars-lazy', defaulting to 'pandas'.
    """

    def test_pandas_dataframe_from_instance(self):
        """Test detecting pandas DataFrame from task instance."""

        class _PandasTaskInstance(TaskOnKart[pd.DataFrame]):
            pass

        task = _PandasTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_pandas_dataframe_from_class(self):
        """Test detecting pandas DataFrame from task class."""

        class _PandasTaskClass(TaskOnKart[pd.DataFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(_PandasTaskClass), 'pandas')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_dataframe_from_instance(self):
        """Test detecting polars DataFrame from task instance."""

        class _PolarsTaskInstance(TaskOnKart[pl.DataFrame]):
            pass

        task = _PolarsTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_dataframe_from_class(self):
        """Test detecting polars DataFrame from task class."""

        class _PolarsTaskClass(TaskOnKart[pl.DataFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(_PolarsTaskClass), 'polars')

    def test_no_type_parameter_defaults_to_pandas(self):
        """Test that tasks without type parameter default to pandas."""

        # Create a class without __orig_bases__ by not using type parameters
        class PlainTask:
            pass

        task = PlainTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_non_taskonkart_class_defaults_to_pandas(self):
        """Test that non-TaskOnKart classes default to pandas."""

        class RegularClass:
            pass

        task = RegularClass()
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_taskonkart_with_non_dataframe_type(self):
        """Test TaskOnKart with non-DataFrame type parameter defaults to pandas."""

        class _StringTask(TaskOnKart[str]):
            pass

        task = _StringTask()
        # Should default to pandas since str module is not 'pandas' or 'polars'
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    def test_nested_inheritance_pandas(self):
        """Test that nested inheritance without direct type parameter defaults to pandas."""

        class _BasePandasTask(TaskOnKart[pd.DataFrame]):
            pass

        class _DerivedPandasTask(_BasePandasTask):
            pass

        task = _DerivedPandasTask()
        # _DerivedPandasTask doesn't have its own __orig_bases__ with type parameter,
        # so it defaults to 'pandas'
        self.assertEqual(get_dataframe_type_from_task(task), 'pandas')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars(self):
        """Test detecting polars DataFrame type through nested inheritance."""

        class _BasePolarsTask(TaskOnKart[pl.DataFrame]):
            pass

        class _DerivedPolarsTask(_BasePolarsTask):
            pass

        task = _DerivedPolarsTask()
        # Function should detect 'polars' through the inheritance chain
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_lazyframe_from_instance(self):
        """Test detecting polars LazyFrame from a task instance."""

        class _LazyTaskInstance(TaskOnKart[pl.LazyFrame]):
            pass

        task = _LazyTaskInstance()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_polars_lazyframe_from_class(self):
        """Test detecting polars LazyFrame from a task class."""

        class _LazyTaskClass(TaskOnKart[pl.LazyFrame]):
            pass

        self.assertEqual(get_dataframe_type_from_task(_LazyTaskClass), 'polars-lazy')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars_lazyframe(self):
        """Test detecting polars LazyFrame through nested inheritance."""

        class _BaseLazyTask(TaskOnKart[pl.LazyFrame]):
            pass

        class _DerivedLazyTask(_BaseLazyTask):
            pass

        task = _DerivedLazyTask()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars-lazy')

    @pytest.mark.skipif(not HAS_POLARS, reason='polars not installed')
    def test_nested_inheritance_polars_with_mixin(self):
        """Derived class with multiple bases should still detect polars through MRO."""

        class _Mixin:
            pass

        class _BasePolarsTaskWithMixin(TaskOnKart[pl.DataFrame]):
            pass

        # Multiple inheritance gives _DerivedTask its own __orig_bases__,
        # which shadows the parent's and doesn't contain TaskOnKart[...].
        class _DerivedTaskWithMixin(_BasePolarsTaskWithMixin, _Mixin):
            pass

        task = _DerivedTaskWithMixin()
        self.assertEqual(get_dataframe_type_from_task(task), 'polars')
================================================
FILE: test/test_worker.py
================================================
import uuid
from unittest.mock import Mock
import luigi
import luigi.worker
import pytest
from luigi import scheduler
import gokart
from gokart.worker import Worker, gokart_worker
class _DummyTask(gokart.TaskOnKart[str]):
    """Minimal task used to observe whether run() is invoked by the worker."""

    task_namespace = __name__
    # Randomized per test so each instance gets a fresh, not-yet-completed unique id.
    random_id: luigi.StrParameter = luigi.StrParameter()

    # Hook point for tests to monkeypatch a spy into the run body.
    def _run(self): ...

    def run(self):
        self._run()
        self.dump('test')
class TestWorkerRun:
    """Tests that the gokart worker executes incomplete tasks."""

    def test_run(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Check run is called when the task is not completed"""
        local_scheduler = scheduler.Scheduler()
        active_worker = Worker(scheduler=local_scheduler)
        target_task = _DummyTask(random_id=uuid.uuid4().hex)

        # Spy on the task body via the _run hook.
        run_spy = Mock()
        monkeypatch.setattr(target_task, '_run', run_spy)

        with active_worker:
            assert active_worker.add(target_task)
            assert active_worker.run()

        run_spy.assert_called_once()
class _DummyTaskToCheckSkip(gokart.TaskOnKart[None]):
    """Task that never reports completion; used to exercise worker-side skipping."""

    task_namespace = __name__

    # Hook point for tests to monkeypatch a spy into the run body.
    def _run(self): ...

    def run(self):
        self._run()
        self.dump(None)

    def complete(self) -> bool:
        # Always incomplete, so completion behavior is driven by mocks in the tests.
        return False
class TestWorkerSkipIfCompletedPreRun:
    """Verify the worker-level pre-run completion check (``task_completion_check_at_run``)."""

    @pytest.mark.parametrize(
        'task_completion_check_at_run,is_completed,expect_skipped',
        [
            pytest.param(True, True, True, id='skipped when completed and task_completion_check_at_run is True'),
            pytest.param(True, False, False, id='not skipped when not completed and task_completion_check_at_run is True'),
            pytest.param(False, True, False, id='not skipped when completed and task_completion_check_at_run is False'),
            pytest.param(False, False, False, id='not skipped when not completed and task_completion_check_at_run is False'),
        ],
    )
    def test_skip_task(self, monkeypatch: pytest.MonkeyPatch, task_completion_check_at_run: bool, is_completed: bool, expect_skipped: bool) -> None:
        sch = scheduler.Scheduler()
        worker = Worker(scheduler=sch, config=gokart_worker(task_completion_check_at_run=task_completion_check_at_run))
        mock_complete = Mock(return_value=is_completed)
        # NOTE: set `complete_check_at_run=False` to avoid using deprecated skip logic.
        task = _DummyTaskToCheckSkip(complete_check_at_run=False)
        mock_run = Mock()
        monkeypatch.setattr(task, '_run', mock_run)
        with worker:
            assert worker.add(task)
            # NOTE: mock `complete` after `add` because `add` calls `complete`
            # to check if the task is already completed.
            monkeypatch.setattr(task, 'complete', mock_complete)
            assert worker.run()
        # The task body must run unless the worker-level check skipped it.
        if expect_skipped:
            mock_run.assert_not_called()
        else:
            mock_run.assert_called_once()
class TestWorkerCheckCompleteValue:
    """Tests for Worker._check_complete_value input validation."""

    def test_does_not_raise_for_boolean_values(self) -> None:
        target_worker = Worker(scheduler=scheduler.Scheduler())
        # Both booleans are valid return values of Task.complete().
        for value in (True, False):
            target_worker._check_complete_value(value)

    def test_raises_async_completion_exception_for_traceback_wrapper(self) -> None:
        # NOTE: When Task.complete() raises in an async check, the exception is wrapped
        # in TracebackWrapper. This branch must raise AsyncCompletionException.
        target_worker = Worker(scheduler=scheduler.Scheduler())
        wrapped_trace = luigi.worker.TracebackWrapper(trace='dummy traceback')
        with pytest.raises(luigi.worker.AsyncCompletionException):
            target_worker._check_complete_value(wrapped_trace)

    def test_raises_exception_for_non_boolean_value(self) -> None:
        # NOTE: Pass a non-bool value to verify the runtime guard against a misimplemented
        # Task.complete() returning a non-boolean. The type ignore is intentional.
        target_worker = Worker(scheduler=scheduler.Scheduler())
        with pytest.raises(Exception, match='Return value of Task.complete'):
            target_worker._check_complete_value('not a bool')  # type: ignore[arg-type]
================================================
FILE: test/test_zoned_date_second_parameter.py
================================================
import datetime
import unittest
from luigi.cmdline_parser import CmdlineParser
from gokart import TaskOnKart, ZonedDateSecondParameter
class ZonedDateSecondParameterTaskWithoutDefault(TaskOnKart[datetime.datetime]):
    """Task whose ``dt`` parameter has no default; exercises required-parameter parsing."""

    task_namespace = __name__
    dt: ZonedDateSecondParameter = ZonedDateSecondParameter()

    def run(self):
        self.dump(self.dt)
class ZonedDateSecondParameterTaskWithDefault(TaskOnKart[datetime.datetime]):
    """Task whose ``dt`` parameter defaults to 2025-02-21T12:00:00+09:00."""

    task_namespace = __name__
    dt: ZonedDateSecondParameter = ZonedDateSecondParameter(
        default=datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9)))
    )

    def run(self):
        self.dump(self.dt)
class ZonedDateSecondParameterTest(unittest.TestCase):
    """Parsing and serialization tests for ZonedDateSecondParameter."""

    def setUp(self):
        plus_nine = datetime.timezone(datetime.timedelta(hours=9))
        self.default_datetime = datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=plus_nine)
        self.default_datetime_str = '2025-02-21T12:00:00+09:00'

    def test_default(self):
        with CmdlineParser.global_instance([f'{__name__}.ZonedDateSecondParameterTaskWithDefault']) as cp:
            self.assertEqual(cp.get_task_obj().dt, self.default_datetime)

    def test_parse_param_with_tz_suffix(self):
        args = [f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00+09:00']
        with CmdlineParser.global_instance(args) as cp:
            expected = datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=9)))
            self.assertEqual(cp.get_task_obj().dt, expected)

    def test_parse_param_with_Z_suffix(self):
        args = [f'{__name__}.ZonedDateSecondParameterTaskWithDefault', '--dt', '2024-01-20T11:00:00Z']
        with CmdlineParser.global_instance(args) as cp:
            # 'Z' parses as an offset of zero hours (UTC).
            expected = datetime.datetime(2024, 1, 20, 11, 0, 0, tzinfo=datetime.timezone(datetime.timedelta(hours=0)))
            self.assertEqual(cp.get_task_obj().dt, expected)

    def test_parse_param_without_timezone_input(self):
        args = [f'{__name__}.ZonedDateSecondParameterTaskWithoutDefault', '--dt', '2025-02-21T12:00:00']
        with CmdlineParser.global_instance(args) as cp:
            # A naive input string yields a naive datetime (tzinfo=None).
            self.assertEqual(cp.get_task_obj().dt, datetime.datetime(2025, 2, 21, 12, 0, 0, tzinfo=None))

    def test_parse_method(self):
        parsed = ZonedDateSecondParameter().parse(self.default_datetime_str)
        self.assertEqual(parsed, self.default_datetime)

    def test_serialize_task(self):
        task = ZonedDateSecondParameterTaskWithoutDefault(dt=self.default_datetime)
        serialized = str(task)
        # The serialized form must end with the ISO-8601 representation of dt.
        self.assertTrue(serialized.endswith(f'(dt={self.default_datetime_str})'))
# Allow running this test module directly with ``python <file>``.
if __name__ == '__main__':
    unittest.main()
================================================
FILE: test/testing/__init__.py
================================================
================================================
FILE: test/testing/test_pandas_assert.py
================================================
import unittest
import pandas as pd
import gokart
class TestPandasAssert(unittest.TestCase):
    """Tests for ``gokart.testing.assert_frame_contents_equal``."""

    def test_assert_frame_contents_equal(self):
        """Frames with shuffled rows/columns but identical contents must pass."""
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])
        gokart.testing.assert_frame_contents_equal(resulted, expected)

    def test_assert_frame_contents_equal_with_small_error(self):
        """Small float deviations within ``atol`` must be tolerated."""
        expected = pd.DataFrame(data=dict(f1=[1.0001, 2.0001, 3.0001], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2.0002, 1.0002, 3.0002], f3=[222, 111, 333]), index=[1, 0, 2])
        gokart.testing.assert_frame_contents_equal(resulted, expected, atol=1e-1)

    def test_assert_frame_contents_equal_with_duplicated_columns(self):
        """Duplicated column labels cannot be aligned, so the assert must fail."""
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        expected.columns = ['f1', 'f1', 'f2']
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])
        resulted.columns = ['f2', 'f1', 'f1']
        with self.assertRaises(AssertionError):
            gokart.testing.assert_frame_contents_equal(resulted, expected)

    def test_assert_frame_contents_equal_with_duplicated_indexes(self):
        """Duplicated index labels cannot be aligned, so the assert must fail."""
        expected = pd.DataFrame(data=dict(f1=[1, 2, 3], f3=[111, 222, 333], f2=[4, 5, 6]), index=[0, 1, 2])
        expected.index = [0, 1, 1]
        resulted = pd.DataFrame(data=dict(f2=[5, 4, 6], f1=[2, 1, 3], f3=[222, 111, 333]), index=[1, 0, 2])
        # BUG FIX: the original assigned ``expected.index`` a second time here,
        # leaving ``resulted`` with unique indexes (so the duplicated-index path
        # on the result side was never exercised). Mirror the duplicated-columns
        # test by giving ``resulted`` the duplicated index instead.
        resulted.index = [1, 0, 1]
        with self.assertRaises(AssertionError):
            gokart.testing.assert_frame_contents_equal(resulted, expected)
================================================
FILE: test/tree/__init__.py
================================================
================================================
FILE: test/tree/test_task_info.py
================================================
from __future__ import annotations
import unittest
from typing import Any
from unittest.mock import patch
import luigi
import luigi.mock
from luigi.mock import MockFileSystem, MockTarget
import gokart
from gokart.tree.task_info import dump_task_info_table, dump_task_info_tree, make_task_info_as_tree_str, make_task_info_tree
class _SubTask(gokart.TaskOnKart[str]):
    """Leaf task used as a dependency; dumps its own unique id as text."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter()

    def output(self):
        return self.make_target('sub_task.txt')

    def run(self):
        self.dump('task uid = {}'.format(self.make_unique_id()))
class _Task(gokart.TaskOnKart[str]):
    """Task depending on a single _SubTask passed as a task-instance parameter."""

    task_namespace = __name__
    param: luigi.IntParameter = luigi.IntParameter(default=10)
    sub: gokart.TaskInstanceParameter[_SubTask] = gokart.TaskInstanceParameter(default=_SubTask(param=20))

    def requires(self):
        return self.sub

    def output(self):
        return self.make_target('task.txt')

    def run(self):
        self.dump('task uid = {}'.format(self.make_unique_id()))
class _DoubleLoadSubTask(gokart.TaskOnKart[str]):
    """Task taking two task-instance parameters, used to exercise shared subtrees."""

    task_namespace = __name__
    sub1: gokart.TaskInstanceParameter[gokart.TaskOnKart[Any]] = gokart.TaskInstanceParameter()
    sub2: gokart.TaskInstanceParameter[gokart.TaskOnKart[Any]] = gokart.TaskInstanceParameter()

    def output(self):
        return self.make_target('sub_task.txt')

    def run(self):
        self.dump('task uid = {}'.format(self.make_unique_id()))
class TestInfo(unittest.TestCase):
    """Tests for the task-info tree helpers (make_task_info_as_tree_str / make_task_info_tree)."""

    def setUp(self) -> None:
        # Start each test from an empty in-memory filesystem and force luigi
        # to re-run its logging setup, which it otherwise configures only once.
        MockFileSystem().clear()
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    def tearDown(self) -> None:
        # Reset the logging-configured flags so later tests are unaffected.
        luigi.setup_logging.DaemonLogging._configured = False
        luigi.setup_logging.InterfaceLogging._configured = False

    # Route every LocalTarget to luigi's in-memory MockTarget so no file I/O occurs.
    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_pending(self):
        task = _Task(param=1, sub=_SubTask(param=2))
        # check before running
        tree = make_task_info_as_tree_str(task)
        expected = r"""
└─-\(PENDING\) _Task\[[a-z0-9]*\]
└─-\(PENDING\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_complete(self):
        task = _Task(param=1, sub=_SubTask(param=2))
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = make_task_info_as_tree_str(task)
        expected = r"""
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_abbreviation(self):
        # Both sub-parameters are equal tasks; the second occurrence of the
        # shared subtree should be abbreviated as '...' in the rendered tree.
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = make_task_info_as_tree_str(task)
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]
\|--\(COMPLETE\) _Task\[[a-z0-9]*\]
\| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─- \.\.\.$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_not_compress(self):
        # With abbr=False the duplicated subtree must be rendered in full.
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = make_task_info_as_tree_str(task, abbr=False)
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]
\|--\(COMPLETE\) _Task\[[a-z0-9]*\]
\| └─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]
└─-\(COMPLETE\) _Task\[[a-z0-9]*\]
└─-\(COMPLETE\) _SubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_not_compress_ignore_task(self):
        # Tasks listed in ignore_task_names (and their subtrees) are omitted,
        # leaving only the root in the rendered tree.
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check after sub task runs
        gokart.build(task, reset_register=False)
        tree = make_task_info_as_tree_str(task, abbr=False, ignore_task_names=['_Task'])
        expected = r"""
└─-\(COMPLETE\) _DoubleLoadSubTask\[[a-z0-9]*\]$"""
        self.assertRegex(tree, expected)

    @patch('luigi.LocalTarget', new=lambda path, **kwargs: MockTarget(path, **kwargs))
    def test_make_tree_info_with_cache(self):
        task = _DoubleLoadSubTask(
            sub1=_Task(param=1, sub=_SubTask(param=2)),
            sub2=_Task(param=1, sub=_SubTask(param=2)),
        )
        # check child task_info is the same object
        tree = make_task_info_tree(task)
        self.assertTrue(tree.children_task_infos[0] is tree.children_task_infos[1])
class _TaskInfoExampleTaskA(gokart.TaskOnKart[Any]):
    # Minimal stub task used as a requirement in the task-info dump tests.
    task_namespace = __name__
class _TaskInfoExampleTaskB(gokart.TaskOnKart[Any]):
    # Minimal stub task; typically the one excluded via ignore_task_names.
    task_namespace = __name__
class _TaskInfoExampleTaskC(gokart.TaskOnKart[str]):
    """Root example task requiring taskA and taskB; dumps a constant marker."""

    task_namespace = __name__

    def requires(self):
        return {'taskA': _TaskInfoExampleTaskA(), 'taskB': _TaskInfoExampleTaskB()}

    def run(self):
        self.dump('DONE')
class TestTaskInfoTable(unittest.TestCase):
    """dump_task_info_table should build one row per task, honoring ignore_task_names."""

    def test_dump_task_info_table(self):
        # Intercept the target dump so we can inspect the table in memory.
        with patch('gokart.target.SingleFileTarget.dump') as mock_obj:
            self.dumped_data: Any = None

            def _capture(obj, lock_at_dump):
                self.dumped_data = obj

            mock_obj.side_effect = _capture
            dump_task_info_table(task=_TaskInfoExampleTaskC(), task_info_dump_path='path.csv', ignore_task_names=['_TaskInfoExampleTaskB'])

            table = self.dumped_data
            # '_TaskInfoExampleTaskB' must be excluded via ignore_task_names.
            self.assertEqual({'_TaskInfoExampleTaskA', '_TaskInfoExampleTaskC'}, set(table['name']))
            expected_columns = {'name', 'unique_id', 'output_paths', 'params', 'processing_time', 'is_complete', 'task_log', 'requires'}
            self.assertEqual(expected_columns, set(table.columns))
class TestTaskInfoTree(unittest.TestCase):
    """dump_task_info_tree should serialize the tree and validate the dump extension."""

    def test_dump_task_info_tree(self):
        # Intercept the target dump so we can inspect the tree object in memory.
        with patch('gokart.target.SingleFileTarget.dump') as mock_obj:
            self.dumped_data: Any = None

            def _capture(obj, lock_at_dump):
                self.dumped_data = obj

            mock_obj.side_effect = _capture
            dump_task_info_tree(task=_TaskInfoExampleTaskC(), task_info_dump_path='path.pkl', ignore_task_names=['_TaskInfoExampleTaskB'])

            root = self.dumped_data
            self.assertEqual(root.name, '_TaskInfoExampleTaskC')
            # The ignored task is absent from children but still listed in `requires`.
            self.assertEqual(root.children_task_infos[0].name, '_TaskInfoExampleTaskA')
            self.assertEqual(root.requires.keys(), {'taskA', 'taskB'})
            self.assertEqual(root.requires['taskA'].name, '_TaskInfoExampleTaskA')
            self.assertEqual(root.requires['taskB'].name, '_TaskInfoExampleTaskB')

    def test_dump_task_info_tree_with_invalid_path_extention(self):
        with patch('gokart.target.SingleFileTarget.dump') as mock_obj:
            self.dumped_data = None

            def _capture(obj, lock_at_dump):
                self.dumped_data = obj

            mock_obj.side_effect = _capture
            # A non-pickle extension must be rejected before anything is dumped.
            with self.assertRaises(AssertionError):
                dump_task_info_tree(task=_TaskInfoExampleTaskC(), task_info_dump_path='path.csv', ignore_task_names=['_TaskInfoExampleTaskB'])
================================================
FILE: test/tree/test_task_info_formatter.py
================================================
import unittest
from typing import Any
import gokart
from gokart.tree.task_info_formatter import RequiredTask, _make_requires_info
class _RequiredTaskExampleTaskA(gokart.TaskOnKart[Any]):
    # Minimal stub task used as the `requires` value in the formatter tests.
    task_namespace = __name__
class TestMakeRequiresInfo(unittest.TestCase):
    """_make_requires_info should mirror the structure of `requires` with RequiredTask entries."""

    @staticmethod
    def _to_required_task(task):
        # The RequiredTask we expect _make_requires_info to produce for one task.
        return RequiredTask(name=type(task).__name__, unique_id=task.make_unique_id())

    def test_make_requires_info_with_task_on_kart(self):
        task = _RequiredTaskExampleTaskA()
        self.assertEqual(_make_requires_info(requires=task), self._to_required_task(task))

    def test_make_requires_info_with_list(self):
        tasks = [_RequiredTaskExampleTaskA()]
        expected = [self._to_required_task(task) for task in tasks]
        self.assertEqual(_make_requires_info(requires=tasks), expected)

    def test_make_requires_info_with_generator(self):
        def _requires_gen():
            for _ in range(2):
                yield _RequiredTaskExampleTaskA()

        expected = [self._to_required_task(task) for task in _requires_gen()]
        self.assertEqual(_make_requires_info(requires=_requires_gen()), expected)

    def test_make_requires_info_with_dict(self):
        tasks = dict(taskA=_RequiredTaskExampleTaskA())
        expected = {key: self._to_required_task(task) for key, task in tasks.items()}
        self.assertEqual(_make_requires_info(requires=tasks), expected)

    def test_make_requires_info_with_invalid(self):
        # Plain integers are not tasks; the formatter must raise TypeError.
        with self.assertRaises(TypeError):
            _make_requires_info(requires=[1, 2])
================================================
FILE: test/util.py
================================================
import os
import uuid
# TODO: use pytest.fixture to share this functionality with other tests
def _get_temporary_directory() -> str:
    """Return a unique absolute path for a per-test temporary directory.

    The path is placed next to this file and suffixed with a fresh UUID so
    concurrent test runs never collide.

    BUG FIX: the original used ``os.path.dirname(__name__)`` — the dirname of
    a module *name*, which contains no path separator and is therefore always
    ``''``, silently resolving the path against the current working directory.
    ``__file__`` expresses the actual intent ("next to this file").
    """
    _uuid = str(uuid.uuid4())
    return os.path.abspath(os.path.join(os.path.dirname(__file__), f'temporary-{_uuid}'))
================================================
FILE: tox.ini
================================================
[tox]
envlist = py{310,311,312,313,314},ruff,mypy
skipsdist = True
[testenv]
runner = uv-venv-lock-runner
dependency_groups = test
commands =
{envpython} -m pytest --cov=gokart --cov-report=xml -vv {posargs:}
[testenv:ruff]
dependency_groups = lint
commands =
ruff check {posargs:}
ruff format --check {posargs:}
[testenv:mypy]
dependency_groups = lint
commands = mypy gokart test {posargs:}